From 2ecabde500dc64cf87e362f5b8edb1fdb139b246 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Thu, 23 Apr 2020 06:16:44 +0400 Subject: [PATCH] DL4J NWC support for RNNs (#379) * merge conf * merge conf * conf fix * NWC initial * revert pom.xml * revert pom.xml * default NCW * bidirectional+some tests * RNNOutputLayer, RNNLossLayer, Graves + tests * rnn tests * LastTimeStep + tests * masking + tests * graves, rnnoutput, rnnloss * nwc timeseries reverse * more tests * bi-gravelstm test * fixes * rnn df tests basic * bug fix: cudnn fallback * bug fix * misc * gravelstm tests * preprocessor fixes * TimeDistributed * more tests * RnnLossLayer builder def val * copyright headers * Remove debug println Signed-off-by: Alex Black * Small fix + test naming Signed-off-by: Alex Black * Parameterized test name Signed-off-by: Alex Black * fix LastTimeStep masked * Fix MaskZero mask datatype issue Signed-off-by: Alex Black * rem println * javadoc * Fixes Signed-off-by: Alex Black Co-authored-by: Alex Black --- .../gradientcheck/RnnGradientChecks.java | 2 +- .../deeplearning4j/nn/dtypes/DTypeTests.java | 2 +- .../layers/recurrent/BidirectionalTest.java | 143 ++++--- .../GravesBidirectionalLSTMTest.java | 132 ++++-- .../layers/recurrent/MaskZeroLayerTest.java | 24 +- .../layers/recurrent/RnnDataFormatTests.java | 394 ++++++++++++++++++ .../recurrent/TestLastTimeStepLayer.java | 55 ++- .../nn/layers/recurrent/TestRnnLayers.java | 59 ++- .../nn/layers/recurrent/TestSimpleRnn.java | 43 +- .../layers/recurrent/TestTimeDistributed.java | 41 +- .../convolutional/KerasConvolution1D.java | 5 +- .../keras/layers/recurrent/KerasLSTM.java | 10 +- .../layers/recurrent/KerasSimpleRnn.java | 6 +- .../layers/wrappers/KerasBidirectional.java | 2 +- .../nn/api/layers/RecurrentLayer.java | 2 + .../org/deeplearning4j/nn/conf/RNNFormat.java | 29 ++ .../nn/conf/inputs/InputType.java | 46 +- .../nn/conf/layers/BaseRecurrentLayer.java | 19 +- .../nn/conf/layers/Convolution1DLayer.java | 3 +- .../nn/conf/layers/FeedForwardLayer.java | 2 +- .../nn/conf/layers/InputTypeUtil.java | 7 +- .../layers/LearnedSelfAttentionLayer.java | 3 +- .../nn/conf/layers/LocallyConnected1D.java | 3 +- .../conf/layers/RecurrentAttentionLayer.java | 3 +- .../nn/conf/layers/RnnLossLayer.java | 17 +- .../nn/conf/layers/RnnOutputLayer.java | 18 +- .../nn/conf/layers/SelfAttentionLayer.java | 3 +- .../nn/conf/layers/Subsampling1DLayer.java | 3 +- .../nn/conf/layers/ZeroPadding1DLayer.java | 3 +- .../conf/layers/recurrent/Bidirectional.java | 8 +- .../layers/recurrent/TimeDistributed.java | 16 +- .../preprocessor/CnnToRnnPreProcessor.java | 28 +- .../FeedForwardToRnnPreProcessor.java | 26 +- .../preprocessor/RnnToCnnPreProcessor.java | 26 +- .../RnnToFeedForwardPreProcessor.java | 27 +- .../layers/recurrent/BaseRecurrentLayer.java | 20 +- .../layers/recurrent/BidirectionalLayer.java | 49 ++- .../recurrent/GravesBidirectionalLSTM.java | 23 +- .../nn/layers/recurrent/GravesLSTM.java | 20 +- .../nn/layers/recurrent/LSTM.java | 16 +- .../nn/layers/recurrent/LSTMHelpers.java | 13 +- .../layers/recurrent/LastTimeStepLayer.java | 50 ++- .../nn/layers/recurrent/MaskZeroLayer.java | 9 +- .../nn/layers/recurrent/RnnLossLayer.java | 35 +- .../nn/layers/recurrent/RnnOutputLayer.java | 27 +- .../nn/layers/recurrent/SimpleRnn.java | 17 +- .../recurrent/TimeDistributedLayer.java | 11 +- .../deeplearning4j/util/TimeSeriesUtils.java | 34 ++ 48 files changed, 1256 insertions(+), 278 deletions(-) create mode 100644 deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java create mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java index b81849d53..e356cce1d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java @@ -338,7 +338,7 @@ public class RnnGradientChecks extends BaseDL4JTest { .weightInit(WeightInit.XAVIER) .list() .layer(new LSTM.Builder().nOut(layerSize).build()) - .layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build(), 2)) + .layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build())) .layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .setInputType(InputType.recurrent(nIn)) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index 6831af10b..437141d59 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -819,7 +819,7 @@ public class DTypeTests extends BaseDL4JTest { .layer(new DenseLayer.Builder().nOut(5).build()) .layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()) .layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())) - .layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2)) + .layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build())) .layer(new SimpleRnn.Builder().nIn(5).nOut(5).build()) .layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build()) .layer(secondLast) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java index f7a4c087f..489687679 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java @@ -24,10 +24,7 @@ import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator; import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.WorkspaceMode; +import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM; import org.deeplearning4j.nn.conf.layers.GravesLSTM; @@ -45,6 +42,8 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.TimeSeriesUtils; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -61,12 +60,22 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import static org.deeplearning4j.nn.conf.RNNFormat.NCW; import static org.junit.Assert.assertEquals; @Slf4j +@RunWith(Parameterized.class) public class BidirectionalTest extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + public BidirectionalTest(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void compareImplementations(){ for(WorkspaceMode wsm : WorkspaceMode.values()) { @@ -82,9 +91,9 @@ public class BidirectionalTest extends BaseDL4JTest { .inferenceWorkspaceMode(wsm) .updater(new Adam()) .list() - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) - .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) + .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) + .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) + .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat) .nIn(10).nOut(10).build()) .build(); @@ -95,9 +104,9 @@ public class BidirectionalTest extends BaseDL4JTest { .inferenceWorkspaceMode(wsm) .updater(new Adam()) .list() - .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build()) - .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build()) - .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) + .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) + .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) + .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat) .nIn(10).nOut(10).build()) .build(); @@ -116,15 +125,24 @@ public class BidirectionalTest extends BaseDL4JTest { net2.setParams(net1.params()); //Assuming exact same layout here... - INDArray in = Nd4j.rand(new int[]{3, 10, 5}); + INDArray in; + if (rnnDataFormat == NCW){ + in = Nd4j.rand(new int[]{3, 10, 5}); + }else{ + in = Nd4j.rand(new int[]{3, 5, 10}); + } INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); assertEquals(out1, out2); - INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); - + INDArray labels; + if (rnnDataFormat == NCW){ + labels = Nd4j.rand(new int[]{3, 10, 5}); + }else{ + labels = Nd4j.rand(new int[]{3, 5, 10}); + } net1.setInput(in); net1.setLabels(labels); @@ -276,17 +294,22 @@ public class BidirectionalTest extends BaseDL4JTest { .inferenceWorkspaceMode(wsm) .updater(new Adam()) .list() - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) + .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) + .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) - .nIn(10).nOut(10).build()) + .nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); - INDArray in = Nd4j.rand(new int[]{3, 10, 5}); - INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); + INDArray in; + INDArray labels; + + long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 5} : new long[]{3, 5, 10}; + + in = Nd4j.rand(inshape); + labels = Nd4j.rand(inshape); net1.fit(in, labels); @@ -300,8 +323,8 @@ public class BidirectionalTest extends BaseDL4JTest { MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true); - in = Nd4j.rand(new int[]{3, 10, 5}); - labels = Nd4j.rand(new int[]{3, 10, 5}); + in = Nd4j.rand(inshape); + labels = Nd4j.rand(inshape); INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); @@ -338,18 +361,18 @@ public class BidirectionalTest extends BaseDL4JTest { .updater(new Adam()) .graphBuilder() .addInputs("in") - .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in") - .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0") - .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) + .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in") + .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0") + .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat) .nIn(10).nOut(10).build(), "1") .setOutputs("2") .build(); ComputationGraph net1 = new ComputationGraph(conf1); net1.init(); - - INDArray in = Nd4j.rand(new int[]{3, 10, 5}); - INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); + long[] inshape = (rnnDataFormat == NCW)? new long[]{3, 10, 5}: new long[]{3, 5, 10}; + INDArray in = Nd4j.rand(inshape); + INDArray labels = Nd4j.rand(inshape); net1.fit(new DataSet(in, labels)); @@ -363,8 +386,8 @@ public class BidirectionalTest extends BaseDL4JTest { ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true); - in = Nd4j.rand(new int[]{3, 10, 5}); - labels = Nd4j.rand(new int[]{3, 10, 5}); + in = Nd4j.rand(inshape); + labels = Nd4j.rand(inshape); INDArray out1 = net1.outputSingle(in); INDArray out2 = net2.outputSingle(in); @@ -394,8 +417,8 @@ public class BidirectionalTest extends BaseDL4JTest { Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD, Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL}; - - INDArray in = Nd4j.rand(new int[]{3, 10, 6}); + long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10}; + INDArray in = Nd4j.rand(inshape); for (Bidirectional.Mode m : modes) { MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() @@ -406,7 +429,7 @@ public class BidirectionalTest extends BaseDL4JTest { .inferenceWorkspaceMode(wsm) .updater(new Adam()) .list() - .layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build())) + .layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) .build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); @@ -418,7 +441,7 @@ public class BidirectionalTest extends BaseDL4JTest { .weightInit(WeightInit.XAVIER) .updater(new Adam()) .list() - .layer(new SimpleRnn.Builder().nIn(10).nOut(10).build()) + .layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone()); @@ -434,11 +457,10 @@ public class BidirectionalTest extends BaseDL4JTest { net3.setParam("0_RW", net1.getParam("0_bRW")); net3.setParam("0_b", net1.getParam("0_bb")); - INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); - + INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat); INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); - INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); + INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat); INDArray outExp; switch (m) { @@ -452,7 +474,7 @@ public class BidirectionalTest extends BaseDL4JTest { outExp = out2.add(out3).muli(0.5); break; case CONCAT: - outExp = Nd4j.concat(1, out2, out3); + outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3); break; default: throw new RuntimeException(); @@ -464,25 +486,25 @@ public class BidirectionalTest extends BaseDL4JTest { //Check gradients: if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) { - INDArray eps = Nd4j.rand(new int[]{3, 10, 6}); + INDArray eps = Nd4j.rand(inshape); INDArray eps1; if (m == Bidirectional.Mode.CONCAT) { - eps1 = Nd4j.concat(1, eps, eps); + eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps); } else { eps1 = eps; } net1.setInput(in); net2.setInput(in); - net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)); + net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat)); net1.feedForward(true, false); net2.feedForward(true, false); net3.feedForward(true, false); Pair p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces()); Pair p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); - Pair p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT), LayerWorkspaceMgr.noWorkspaces()); + Pair p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat), LayerWorkspaceMgr.noWorkspaces()); Gradient g1 = p1.getFirst(); Gradient g2 = p2.getFirst(); Gradient g3 = p3.getFirst(); @@ -520,7 +542,9 @@ public class BidirectionalTest extends BaseDL4JTest { Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL}; - INDArray in = Nd4j.rand(new int[]{3, 10, 6}); + long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10}; + INDArray in = Nd4j.rand(inshape); + for (Bidirectional.Mode m : modes) { ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() @@ -532,7 +556,7 @@ public class BidirectionalTest extends BaseDL4JTest { .updater(new Adam()) .graphBuilder() .addInputs("in") - .layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()), "in") + .layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in") .setOutputs("0") .build(); @@ -546,7 +570,7 @@ public class BidirectionalTest extends BaseDL4JTest { .updater(new Adam()) .graphBuilder() .addInputs("in") - .layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in") + .layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in") .setOutputs("0") .build(); @@ -566,9 +590,20 @@ public class BidirectionalTest extends BaseDL4JTest { INDArray out1 = net1.outputSingle(in); INDArray out2 = net2.outputSingle(in); - INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.outputSingle( - TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)), - LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); + INDArray out3; + INDArray inReverse; + if (rnnDataFormat == RNNFormat.NWC){ + inReverse = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1); + out3 = net3.outputSingle(inReverse); + out3 = TimeSeriesUtils.reverseTimeSeries(out3.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1); + + } + else{ + inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); + out3 = net3.outputSingle(inReverse); + out3 = TimeSeriesUtils.reverseTimeSeries(out3, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); + + } INDArray outExp; switch (m) { @@ -582,7 +617,9 @@ public class BidirectionalTest extends BaseDL4JTest { outExp = out2.add(out3).muli(0.5); break; case CONCAT: - outExp = Nd4j.concat(1, out2, out3); + System.out.println(out2.shapeInfoToString()); + System.out.println(out3.shapeInfoToString()); + outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3); break; default: throw new RuntimeException(); @@ -594,22 +631,26 @@ public class BidirectionalTest extends BaseDL4JTest { //Check gradients: if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) { - INDArray eps = Nd4j.rand(new int[]{3, 10, 6}); + INDArray eps = Nd4j.rand(inshape); INDArray eps1; if (m == Bidirectional.Mode.CONCAT) { - eps1 = Nd4j.concat(1, eps, eps); + eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps); } else { eps1 = eps; } + INDArray epsReversed = (rnnDataFormat == NCW)? + TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT): + TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT) + .permute(0, 2, 1); net1.outputSingle(true, false, in); net2.outputSingle(true, false, in); - net3.outputSingle(true, false, TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)); + net3.outputSingle(true, false, inReverse); Gradient g1 = net1.backpropGradient(eps1); Gradient g2 = net2.backpropGradient(eps); - Gradient g3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)); + Gradient g3 = net3.backpropGradient(epsReversed); for (boolean updates : new boolean[]{false, true}) { if (updates) { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index 751b6f6bf..441267e86 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -31,6 +32,8 @@ import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.api.ndarray.INDArray; @@ -42,10 +45,18 @@ import org.nd4j.linalg.primitives.Pair; import static org.junit.Assert.*; - +@RunWith(Parameterized.class) public class GravesBidirectionalLSTMTest extends BaseDL4JTest { private double score = 0.0; + private RNNFormat rnnDataFormat; + public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void testBidirectionalLSTMGravesForwardBasic() { //Very basic test of forward prop. of LSTM layer with a time series. @@ -55,7 +66,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(nHiddenUnits).activation(Activation.TANH).build()) + .nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build()) .build(); val numParams = conf.getLayer().initializer().numParams(conf); @@ -65,22 +76,41 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; + if (rnnDataFormat == RNNFormat.NCW){ + final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1); + final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1}); - final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1); - final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1}); + final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1); + final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1}); - final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1); - final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1}); + final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12); + final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12}); - final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12); - final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12}); + final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15); + final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15}); + } + else{ + final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, 1, nIn); + final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations1.shape(), new long[] {1, 1, nHiddenUnits}); + + final INDArray dataMultiExampleLength1 = Nd4j.ones(10, 1, nIn); + final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations2.shape(), new long[] {10, 1, nHiddenUnits}); + + final INDArray dataSingleExampleLength12 = Nd4j.ones(1, 12, nIn); + final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations3.shape(), new long[] {1, 12, nHiddenUnits}); + + final INDArray dataMultiExampleLength15 = Nd4j.ones(10, 15, nIn); + final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations4.shape(), new long[] {10, 15, nHiddenUnits}); + } - final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15); - final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15}); } @Test @@ -94,14 +124,15 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 } - private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, + private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) { - INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength); + INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength): + Nd4j.ones(miniBatchSize, timeSeriesLength, nIn); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(lstmNHiddenUnits) + .nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat) .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .build(); @@ -114,7 +145,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); assertNotNull(lstm.input()); - INDArray epsilon = Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength); + INDArray epsilon =(rnnDataFormat == RNNFormat.NCW)? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength): + Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits); Pair out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); Gradient outGradient = out.getFirst(); @@ -147,7 +179,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3}); assertNotNull(nextEpsilon); - assertArrayEquals(nextEpsilon.shape(), new long[] {miniBatchSize, nIn, timeSeriesLength}); + if (rnnDataFormat == RNNFormat.NCW) { + assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, nIn, timeSeriesLength}); + }else{ + assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, timeSeriesLength, nIn }); + } //Check update: for (String s : outGradient.gradientForVariable().keySet()) { @@ -226,7 +262,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(layerSize) + .nOut(layerSize).dataFormat(rnnDataFormat) .dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build()) .build(); @@ -237,7 +273,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .instantiate(confBidirectional, null, 0, params, true, params.dataType()); - final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); + final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}): + Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn}); final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()); @@ -265,13 +302,13 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() - .nIn(nIn).nOut(layerSize) + .nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat) .dist(new UniformDistribution(-0.1, 0.1)) .activation(Activation.TANH).updater(new NoOp()).build()) .build(); final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize) + .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat) .weightInit(WeightInit.ZERO).activation(Activation.TANH).build()) .build(); @@ -290,9 +327,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards))); - final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); + final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}): + Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn}); final INDArray sigb = sig.dup(); - reverseColumnsInPlace(sigb.slice(0)); + + if (rnnDataFormat == RNNFormat.NCW) { + reverseColumnsInPlace(sigb.slice(0)); + } + else{ + reverseColumnsInPlace(sigb.slice(0).permute(1, 0)); + } final INDArray recurrentWeightsF = bidirectionalLSTM .getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS); @@ -345,10 +389,14 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f); - final INDArray randSig = Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}); - final INDArray randSigBackwards = randSig.dup(); - reverseColumnsInPlace(randSigBackwards.slice(0)); - + final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}): + Nd4j.rand(new int[] {1, timeSeriesLength, layerSize}); + INDArray randSigBackwards = randSig.dup(); + if (rnnDataFormat == RNNFormat.NCW){ + reverseColumnsInPlace(randSigBackwards.slice(0)); + }else{ + reverseColumnsInPlace(randSigBackwards.slice(0).permute(1, 0)); + } final Pair backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces()); final Pair backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces()); @@ -399,10 +447,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0); final INDArray activation3Reverse = activation3.dup(); - reverseColumnsInPlace(activation3Reverse); + if (rnnDataFormat == RNNFormat.NCW){ + reverseColumnsInPlace(activation3Reverse); + } + else{ + reverseColumnsInPlace(activation3Reverse.permute(1, 0)); + } - assertEquals(activation3Reverse, activation1); assertArrayEquals(activation3Reverse.shape(), activation1.shape()); + assertEquals(activation3Reverse, activation1); + //test backprop now final INDArray refBackGradientReccurrent = @@ -434,7 +488,12 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final INDArray refEpsilon = backprop1.getSecond().dup(); final INDArray backEpsilon = backprop3.getSecond().dup(); - reverseColumnsInPlace(refEpsilon.slice(0)); + if (rnnDataFormat == RNNFormat.NCW) { + reverseColumnsInPlace(refEpsilon.slice(0)); + } + else{ + reverseColumnsInPlace(refEpsilon.slice(0).permute(1, 0)); + } assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6); } @@ -477,10 +536,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .seed(12345).list() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() - .gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2) + .gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat) .build()) .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2) + .lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat) .activation(Activation.TANH).build()) .build(); @@ -492,7 +551,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { INDArray in = Nd4j.rand(new int[] {3, 2, 5}); INDArray labels = Nd4j.rand(new int[] {3, 2, 5}); - + if (rnnDataFormat == RNNFormat.NWC){ + in = in.permute(0, 2, 1); + labels = labels.permute(0, 2, 1); + } net.fit(in, labels); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java index 3ed7e1f3b..7ddc31220 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java @@ -21,11 +21,14 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -36,9 +39,17 @@ import java.util.Collections; import static org.junit.Assert.assertEquals; - +@RunWith(Parameterized.class) public class MaskZeroLayerTest extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + public MaskZeroLayerTest(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void activate() { @@ -57,7 +68,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest { .activation(Activation.IDENTITY) .gateActivationFunction(Activation.IDENTITY) .nIn(2) - .nOut(1) + .nOut(1).dataFormat(rnnDataFormat) .build(); NeuralNetConfiguration conf = new NeuralNetConfiguration(); conf.setLayer(underlying); @@ -72,9 +83,14 @@ public class MaskZeroLayerTest extends BaseDL4JTest { MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue); INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3}); + if (rnnDataFormat == RNNFormat.NWC){ + input = input.permute(0, 2, 1); + } //WHEN INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces()); - + if (rnnDataFormat == RNNFormat.NWC){ + out = out.permute(0, 2,1); + } //THEN output should only be incremented for the non-zero timesteps INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()); INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()); @@ -94,7 +110,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .list() .layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder() - .setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).build()).build()) + .setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java new file mode 100644 index 000000000..43dd93f56 --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java @@ -0,0 +1,394 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.layers.recurrent; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.TestUtils; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM; +import org.deeplearning4j.nn.conf.layers.GravesLSTM; +import org.deeplearning4j.nn.conf.layers.LSTM; +import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +@RunWith(Parameterized.class) +@AllArgsConstructor +public class RnnDataFormatTests extends BaseDL4JTest { + + private boolean helpers; + private boolean lastTimeStep; + private boolean maskZeros; + + @Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}") + public static List params(){ + List ret = new ArrayList<>(); + for (boolean helpers: new boolean[]{true, false}) + for (boolean lastTimeStep: new boolean[]{true, false}) + for (boolean maskZero: new boolean[]{true, false}) + ret.add(new Object[]{helpers, lastTimeStep, maskZero}); + return ret; + } + + + @Test + public void testSimpleRnn() { + try { + + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12); + + INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSimpleRnnNet(RNNFormat.NCW, true, lastTimeStep, maskZeros)) + .net2(getSimpleRnnNet(RNNFormat.NCW, false, lastTimeStep, maskZeros)) + .net3(getSimpleRnnNet(RNNFormat.NWC, true, lastTimeStep, maskZeros)) + .net4(getSimpleRnnNet(RNNFormat.NWC, false, lastTimeStep, maskZeros)) + .inNCW(inNCW) + .labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1)) + .labelsNWC(labelsNWC) + .testLayerIdx(1) + .build(); + + TestCase.testHelper(tc); + + + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testLSTM() { + try { + + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12); + + INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros)) + .net2(getLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros)) + .net3(getLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros)) + .net4(getLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros)) + .inNCW(inNCW) + .labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1)) + .labelsNWC(labelsNWC) + .testLayerIdx(1) + .build(); + + TestCase.testHelper(tc); + + + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + + @Test + public void testGraveLSTM() { + try { + + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12); + + INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getGravesLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros)) + .net2(getGravesLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros)) + .net3(getGravesLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros)) + .net4(getGravesLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros)) + .inNCW(inNCW) + .labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1)) + .labelsNWC(labelsNWC) + .testLayerIdx(1) + .build(); + + TestCase.testHelper(tc); + + + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + + @Test + public void testGraveBiLSTM() { + try { + + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12); + + INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getGravesBidirectionalLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros)) + .net2(getGravesBidirectionalLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros)) + .net3(getGravesBidirectionalLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros)) + .net4(getGravesBidirectionalLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros)) + .inNCW(inNCW) + .labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1)) + .labelsNWC(labelsNWC) + .testLayerIdx(1) + .build(); + + TestCase.testHelper(tc); + + + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + + private MultiLayerNetwork getGravesBidirectionalLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) { + if (setOnLayerAlso) { + return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3) + .dataFormat(format).build(), format, lastTimeStep, maskZeros); + } else { + return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); + } + } + private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) { + if (setOnLayerAlso) { + return getNetWithLayer(new GravesLSTM.Builder().nOut(3) + .dataFormat(format).build(), format, lastTimeStep, maskZeros); + } else { + return getNetWithLayer(new GravesLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); + } + } + + private MultiLayerNetwork getLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) { + if (setOnLayerAlso) { + return getNetWithLayer(new LSTM.Builder().nOut(3) + .dataFormat(format).build(), format, lastTimeStep, maskZeros); + } else { + return getNetWithLayer(new LSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); + } + } + + private MultiLayerNetwork getSimpleRnnNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) { + if (setOnLayerAlso) { + return getNetWithLayer(new SimpleRnn.Builder().nOut(3) + .dataFormat(format).build(), format, lastTimeStep, maskZeros); + } else { + return getNetWithLayer(new SimpleRnn.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); + } + } + private MultiLayerNetwork getNetWithLayer(Layer layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) { + if (maskZeros){ + layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build(); + } + if(lastTimeStep){ + layer = new LastTimeStep(layer); + } + NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + .seed(12345) + .list() + .layer(new LSTM.Builder() + .nIn(3) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build()) + .layer(layer) + .layer( + (lastTimeStep)?new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build(): + new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).dataFormat(format).build() + ) + .setInputType(InputType.recurrent(3, 12, format)); + + MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); + net.init(); + return net; + } + + @AllArgsConstructor + @Data + @NoArgsConstructor + @Builder + private static class TestCase { + private String msg; + private MultiLayerNetwork net1; + private MultiLayerNetwork net2; + private MultiLayerNetwork net3; + private MultiLayerNetwork net4; + private INDArray inNCW; + private INDArray labelsNCW; + private INDArray labelsNWC; + private int testLayerIdx; + private boolean nwcOutput; + + public static void testHelper(TestCase tc) { + + tc.net2.params().assign(tc.net1.params()); + tc.net3.params().assign(tc.net1.params()); + tc.net4.params().assign(tc.net1.params()); + + INDArray inNCW = tc.inNCW; + INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup(); + + INDArray l0_1 = tc.net1.feedForward(inNCW).get(tc.testLayerIdx + 1); + INDArray l0_2 = tc.net2.feedForward(inNCW).get(tc.testLayerIdx + 1); + INDArray l0_3 = tc.net3.feedForward(inNWC).get(tc.testLayerIdx + 1); + INDArray l0_4 = tc.net4.feedForward(inNWC).get(tc.testLayerIdx + 1); + + boolean rank3Out = tc.labelsNCW.rank() == 3; + assertEquals(tc.msg, l0_1, l0_2); + if (rank3Out){ + assertEquals(tc.msg, l0_1, l0_3.permute(0, 2, 1)); + assertEquals(tc.msg, l0_1, l0_4.permute(0, 2, 1)); + } + else{ + assertEquals(tc.msg, l0_1, l0_3); + assertEquals(tc.msg, l0_1, l0_4); + } + INDArray out1 = tc.net1.output(inNCW); + INDArray out2 = tc.net2.output(inNCW); + INDArray out3 = tc.net3.output(inNWC); + INDArray out4 = tc.net4.output(inNWC); + + assertEquals(tc.msg, out1, out2); + if (rank3Out){ + assertEquals(tc.msg, out1, out3.permute(0, 2, 1)); //NWC to NCW + assertEquals(tc.msg, out1, out4.permute(0, 2, 1)); + } + else{ + assertEquals(tc.msg, out1, out3); //NWC to NCW + assertEquals(tc.msg, out1, out4); + } + + + //Test backprop + Pair p1 = tc.net1.calculateGradients(inNCW, tc.labelsNCW, null, null); + Pair p2 = tc.net2.calculateGradients(inNCW, tc.labelsNCW, null, null); + Pair p3 = tc.net3.calculateGradients(inNWC, tc.labelsNWC, null, null); + Pair p4 = tc.net4.calculateGradients(inNWC, tc.labelsNWC, null, null); + + //Inpput gradients + assertEquals(tc.msg, p1.getSecond(), p2.getSecond()); + + assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0, 2, 1)); //Input gradients for NWC input are also in NWC format + assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0, 2, 1)); + + + List diff12 = differentGrads(p1.getFirst(), p2.getFirst()); + List diff13 = differentGrads(p1.getFirst(), p3.getFirst()); + List diff14 = differentGrads(p1.getFirst(), p4.getFirst()); + assertEquals(tc.msg + " " + diff12, 0, diff12.size()); + assertEquals(tc.msg + " " + diff13, 0, diff13.size()); + assertEquals(tc.msg + " " + diff14, 0, diff14.size()); + + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable()); + + tc.net1.fit(inNCW, tc.labelsNCW); + tc.net2.fit(inNCW, tc.labelsNCW); + tc.net3.fit(inNWC, tc.labelsNWC); + tc.net4.fit(inNWC, tc.labelsNWC); + + assertEquals(tc.msg, tc.net1.params(), tc.net2.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net3.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net4.params()); + + //Test serialization + MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); + MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2); + MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3); + MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4); + + out1 = tc.net1.output(inNCW); + assertEquals(tc.msg, out1, net1a.output(inNCW)); + assertEquals(tc.msg, out1, net2a.output(inNCW)); + + if (rank3Out) { + assertEquals(tc.msg, out1, net3a.output(inNWC).permute(0, 2, 1)); //NWC to NCW + assertEquals(tc.msg, out1, net4a.output(inNWC).permute(0, 2, 1)); + } + else{ + assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW + assertEquals(tc.msg, out1, net4a.output(inNWC)); + } + } + + } + private static List differentGrads(Gradient g1, Gradient g2){ + List differs = new ArrayList<>(); + Map m1 = g1.gradientForVariable(); + Map m2 = g2.gradientForVariable(); + for(String s : m1.keySet()){ + INDArray a1 = m1.get(s); + INDArray a2 = m2.get(s); + if(!a1.equals(a2)){ + differs.add(s); + } + } + return differs; + } +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java index 7dd965ffb..9f60d674d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java @@ -21,6 +21,7 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.LSTM; @@ -29,6 +30,8 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.graph.ComputationGraph; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -42,14 +45,25 @@ import static org.nd4j.linalg.activations.Activation.IDENTITY; import static org.nd4j.linalg.activations.Activation.TANH; import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE; + +@RunWith(Parameterized.class) public class TestLastTimeStepLayer extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + + public TestLastTimeStepLayer(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters(name="{0}") + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void testLastTimeStepVertex() { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") .addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder() - .nIn(5).nOut(6).build()), "in") + .nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in") .setOutputs("lastTS") .build(); @@ -59,9 +73,22 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { //First: test without input mask array Nd4j.getRandom().setSeed(12345); Layer l = graph.getLayer("lastTS"); - INDArray in = Nd4j.rand(new int[]{3, 5, 6}); + INDArray in; + if (rnnDataFormat == RNNFormat.NCW){ + in = Nd4j.rand(3, 5, 6); + } + else{ + in = Nd4j.rand(3, 6, 5); + } INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces()); - INDArray expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5)); + INDArray expOut; + if (rnnDataFormat == RNNFormat.NCW){ + expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5)); + } + else{ + expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.point(5), NDArrayIndex.all()); + } + //Forward pass: @@ -76,9 +103,17 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { graph.setLayerMaskArrays(new INDArray[]{inMask}, null); expOut = Nd4j.zeros(3, 6); - expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2))); - expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3))); - expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4))); + if (rnnDataFormat == RNNFormat.NCW){ + expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2))); + expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3))); + expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4))); + } + else{ + expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all())); + expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.point(3), NDArrayIndex.all())); + expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.point(4), NDArrayIndex.all())); + } + outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces()); assertEquals(expOut, outFwd); @@ -97,9 +132,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { .seed(1234) .graphBuilder() .addInputs("in") - .setInputTypes(InputType.recurrent(1)) + .setInputTypes(InputType.recurrent(1, rnnDataFormat)) .addLayer("RNN", new LastTimeStep(new LSTM.Builder() - .nOut(10) + .nOut(10).dataFormat(rnnDataFormat) .build()), "in") .addLayer("dense", new DenseLayer.Builder() .nOut(10) @@ -120,7 +155,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { INDArray fm2 = Nd4j.zeros(1,24); INDArray fm3 = Nd4j.zeros(1,24); fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1); - + if (rnnDataFormat == RNNFormat.NWC){ + f = f.permute(0, 2, 1); + } INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1}); try { cg.output(false, new INDArray[]{f}, new INDArray[]{fm2}); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index 11e45c51d..5b42d95bc 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -20,6 +20,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.dropout.TestDropout; import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.LSTM; @@ -31,6 +32,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -46,8 +49,18 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; +@RunWith(Parameterized.class) public class TestRnnLayers extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + + public TestRnnLayers(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void testTimeStepIs3Dimensional() { @@ -58,8 +71,8 @@ public class TestRnnLayers extends BaseDL4JTest { .updater(new NoOp()) .weightInit(WeightInit.XAVIER) .list() - .layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).build()) - .layer(new LSTM.Builder().nIn(3).nOut(5).build()) + .layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).dataFormat(rnnDataFormat).build()) + .layer(new LSTM.Builder().nIn(3).nOut(5).dataFormat(rnnDataFormat).build()) .layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build()) .build(); @@ -70,9 +83,9 @@ public class TestRnnLayers extends BaseDL4JTest { org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn = (org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0); - INDArray rnnInput3d = Nd4j.create(10, 12, 1); + INDArray rnnInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10,12, 1):Nd4j.create(10, 1, 12); INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(simpleOut.shape(), new long[] {10, 3, 1})); + assertTrue(Arrays.equals(simpleOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 3, 1}:new long[]{10, 1, 3})); INDArray rnnInput2d = Nd4j.create(10, 12); try { @@ -84,9 +97,9 @@ public class TestRnnLayers extends BaseDL4JTest { org.deeplearning4j.nn.layers.recurrent.LSTM lstm = (org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1); - INDArray lstmInput3d = Nd4j.create(10, 3, 1); + INDArray lstmInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10, 3, 1):Nd4j.create(10, 1, 3); INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(lstmOut.shape(), new long[] {10, 5, 1})); + assertTrue(Arrays.equals(lstmOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 5, 1}:new long[]{10, 1, 5})); INDArray lstmInput2d = Nd4j.create(10, 3); try { @@ -112,19 +125,19 @@ public class TestRnnLayers extends BaseDL4JTest { TestDropout.CustomDropout cd = new TestDropout.CustomDropout(); switch (s){ case "graves": - layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build(); + layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); break; case "lstm": - layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build(); + layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); break; case "simple": - layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build(); + layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); break; default: throw new RuntimeException(s); @@ -134,21 +147,21 @@ public class TestRnnLayers extends BaseDL4JTest { .seed(12345) .list() .layer(layer) - .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) + .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder() .seed(12345) .list() .layer(layerD) - .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) + .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder() .seed(12345) .list() .layer(layerD2) - .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) + .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -178,7 +191,6 @@ public class TestRnnLayers extends BaseDL4JTest { assertNotEquals(s, out2, out2D); INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345); - net.fit(f.dup(), l); netD.fit(f.dup(), l); assertNotEquals(s, net.params(), netD.params()); @@ -209,10 +221,10 @@ public class TestRnnLayers extends BaseDL4JTest { switch (i){ case 0: - lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build()); + lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).dataFormat(rnnDataFormat).build()); break; case 1: - lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()); + lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).dataFormat(rnnDataFormat).build()); break; default: throw new RuntimeException(); @@ -224,13 +236,16 @@ public class TestRnnLayers extends BaseDL4JTest { INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5); INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10); - + if (rnnDataFormat == RNNFormat.NWC){ + l = l.permute(0, 2, 1); + } try{ net.fit(in,l); } catch (Throwable t){ String msg = t.getMessage(); if(msg == null) t.printStackTrace(); + System.out.println(i); assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label")); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java index bf8b964b1..7d61316d5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java @@ -20,10 +20,13 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -36,8 +39,18 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.point; +@RunWith(Parameterized.class) public class TestSimpleRnn extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + + public TestSimpleRnn(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void testSimpleRnn(){ Nd4j.getRandom().setSeed(12345); @@ -46,7 +59,15 @@ public class TestSimpleRnn extends BaseDL4JTest { int nIn = 5; int layerSize = 6; int tsLength = 7; - INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength}); + INDArray in; + if (rnnDataFormat == RNNFormat.NCW){ + in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength}); + } + else{ + in = Nd4j.rand(DataType.FLOAT, new int[]{m, tsLength, nIn}); + } + + // in.get(all(), all(), interval(1,tsLength)).assign(0); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() @@ -54,7 +75,7 @@ public class TestSimpleRnn extends BaseDL4JTest { .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) .list() - .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).build()) + .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -68,7 +89,13 @@ public class TestSimpleRnn extends BaseDL4JTest { INDArray outLast = null; for( int i=0; i 1) throw new InvalidKerasConfigurationException( - "Keras LSTM layer accepts only one input (received " + inputType.length + ")"); - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); + "Keras Conv1D layer accepts only one input (received " + inputType.length + ")"); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], RNNFormat.NCW,layerName); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index 1c205bbca..0888b3376 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -22,11 +22,9 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; -import org.deeplearning4j.nn.conf.layers.InputTypeUtil; -import org.deeplearning4j.nn.conf.layers.LSTM; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; @@ -37,6 +35,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.deeplearning4j.nn.weights.IWeightInit; +import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -266,7 +265,8 @@ public class KerasLSTM extends KerasLayer { throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one single input" + "or three (input to LSTM and two states tensors, but " + "received " + inputType.length + "."); - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); + RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f,layerName); } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index f6ecbb6a5..67fe611e1 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -21,7 +21,9 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.Layer; @@ -36,6 +38,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.params.SimpleRnnParamInitializer; import org.deeplearning4j.nn.weights.IWeightInit; +import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -227,7 +230,8 @@ public class KerasSimpleRnn extends KerasLayer { throw new InvalidKerasConfigurationException( "Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")"); - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); + RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName); } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java index 3b7cb1721..c8cc4fc20 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java @@ -218,7 +218,7 @@ public class KerasBidirectional extends KerasLayer { if (inputType.length > 1) throw new InvalidKerasConfigurationException( "Keras Bidirectional layer accepts only one input (received " + inputType.length + ")"); - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], ((Bidirectional)layer).getRNNDataFormat(), layerName); } /** diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java index 98cce2d27..0fe2a8689 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.api.layers; import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.Gradient; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -98,4 +99,5 @@ public interface RecurrentLayer extends Layer { */ Pair tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr); + } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java new file mode 100644 index 000000000..186b405e7 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.deeplearning4j.nn.conf; + +/** + * NCW = "channels first" - arrays of shape [minibatch, channels, width]
+ * NWC = "channels last" - arrays of shape [minibatch, width, channels]
+ * "width" corresponds to sequence length and "channels" corresponds to sequence item size. + */ + +public enum RNNFormat { + NCW, + NWC +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java index cc9622905..2c7a4e5f8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.nd4j.linalg.api.ndarray.INDArray; @@ -111,9 +112,16 @@ public abstract class InputType implements Serializable { * @return InputTypeRecurrent */ public static InputType recurrent(long size, long timeSeriesLength) { - return new InputTypeRecurrent(size, timeSeriesLength); + return new InputTypeRecurrent(size, timeSeriesLength, RNNFormat.NCW); } + public static InputType recurrent(long size, RNNFormat format){ + return new InputTypeRecurrent(size, format); + } + + public static InputType recurrent(long size, long timeSeriesLength, RNNFormat format){ + return new InputTypeRecurrent(size, timeSeriesLength, format); + } /** * Input type for convolutional (CNN) data, that is 4d with shape [miniBatchSize, channels, height, width]. * For CNN data that has been flattened, use {@link #convolutionalFlat(long, long, long)} @@ -216,14 +224,23 @@ public abstract class InputType implements Serializable { public static class InputTypeRecurrent extends InputType { private long size; private long timeSeriesLength; - + private RNNFormat format = RNNFormat.NCW; public InputTypeRecurrent(long size) { this(size, -1); } + public InputTypeRecurrent(long size, long timeSeriesLength){ + this(size, timeSeriesLength, RNNFormat.NCW); + } - public InputTypeRecurrent(@JsonProperty("size") long size, @JsonProperty("timeSeriesLength") long timeSeriesLength) { + public InputTypeRecurrent(long size, RNNFormat format){ + this(size, -1, format); + } + public InputTypeRecurrent(@JsonProperty("size") long size, + @JsonProperty("timeSeriesLength") long timeSeriesLength, + @JsonProperty("format") RNNFormat format) { this.size = size; this.timeSeriesLength = timeSeriesLength; + this.format = format; } @Override @@ -234,9 +251,9 @@ public abstract class InputType implements Serializable { @Override public String toString() { if (timeSeriesLength > 0) { - return "InputTypeRecurrent(" + size + ",timeSeriesLength=" + timeSeriesLength + ")"; + return "InputTypeRecurrent(" + size + ",timeSeriesLength=" + timeSeriesLength + ",format=" + format + ")"; } else { - return "InputTypeRecurrent(" + size + ")"; + return "InputTypeRecurrent(" + size + ",format=" + format + ")"; } } @@ -251,8 +268,23 @@ public abstract class InputType implements Serializable { @Override public long[] getShape(boolean includeBatchDim) { - if(includeBatchDim) return new long[]{-1, size, timeSeriesLength}; - else return new long[]{size, timeSeriesLength}; + if (includeBatchDim){ + if (format == RNNFormat.NCW){ + return new long[]{-1, size, timeSeriesLength}; + } + else{ + return new long[]{-1, timeSeriesLength, size}; + } + + } + else{ + if (format == RNNFormat.NCW){ + return new long[]{size, timeSeriesLength}; + } + else{ + return new long[]{timeSeriesLength, size}; + } + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java index 07bb3d674..fda5aba83 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.weights.IWeightInit; @@ -35,10 +36,12 @@ import java.util.List; public abstract class BaseRecurrentLayer extends FeedForwardLayer { protected IWeightInit weightInitFnRecurrent; + protected RNNFormat rnnDataFormat = RNNFormat.NCW; protected BaseRecurrentLayer(Builder builder) { super(builder); this.weightInitFnRecurrent = builder.weightInitFnRecurrent; + this.rnnDataFormat = builder.rnnDataFormat; } @Override @@ -51,7 +54,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType; - return InputType.recurrent(nOut, itr.getTimeSeriesLength()); + return InputType.recurrent(nOut, itr.getTimeSeriesLength(), itr.getFormat()); } @Override @@ -64,12 +67,13 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { if (nIn <= 0 || override) { InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; this.nIn = r.getSize(); + this.rnnDataFormat = r.getFormat(); } } @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat,getLayerName()); } @NoArgsConstructor @@ -77,6 +81,12 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { @Setter public static abstract class Builder> extends FeedForwardLayer.Builder { + /** + * Set the format of data expected by the RNN. NCW = [miniBatchSize, size, timeSeriesLength], + * NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW. + */ + protected RNNFormat rnnDataFormat = RNNFormat.NCW; + /** * Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no * constraints.
Constraints can be used to enforce certain conditions (non-negativity of parameters, @@ -163,5 +173,10 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { this.setWeightInitFnRecurrent(new WeightInitDistribution(dist)); return (T) this; } + + public T dataFormat(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + return (T)this; + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index b220ba5a6..32b57f34b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -22,6 +22,7 @@ import lombok.NoArgsConstructor; import lombok.ToString; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.Convolution1DUtils; @@ -114,7 +115,7 @@ public class Convolution1DLayer extends ConvolutionLayer { + "\"): input is null"); } - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName()); } public static class Builder extends ConvolutionLayer.BaseConvBuilder { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java index e15a41781..9b49f6415 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java @@ -87,7 +87,7 @@ public abstract class FeedForwardLayer extends BaseLayer { return null; case RNN: //RNN -> FF - return new RnnToFeedForwardPreProcessor(); + return new RnnToFeedForwardPreProcessor(((InputType.InputTypeRecurrent)inputType).getFormat()); case CNN: //CNN -> FF InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java index 655f0e880..06a71f01d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java @@ -22,6 +22,7 @@ import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor; @@ -528,7 +529,7 @@ public class InputTypeUtil { } } - public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, String layerName) { + public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, RNNFormat rnnDataFormat, String layerName) { if (inputType == null) { throw new IllegalStateException( "Invalid input for RNN layer (layer name = \"" + layerName + "\"): input type is null"); @@ -539,14 +540,14 @@ public class InputTypeUtil { case CNNFlat: //FF -> RNN or CNNFlat -> RNN //In either case, input data format is a row vector per example - return new FeedForwardToRnnPreProcessor(); + return new FeedForwardToRnnPreProcessor(rnnDataFormat); case RNN: //RNN -> RNN: No preprocessor necessary return null; case CNN: //CNN -> RNN InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; - return new CnnToRnnPreProcessor(c.getHeight(), c.getWidth(), c.getChannels()); + return new CnnToRnnPreProcessor(c.getHeight(), c.getWidth(), c.getChannels(), rnnDataFormat); default: throw new RuntimeException("Unknown input type: " + inputType); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java index 0c3c6e383..a0fa4d680 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; @@ -86,7 +87,7 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index 8fedee7b0..d43423bf4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -20,6 +20,7 @@ import lombok.*; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; @@ -136,7 +137,7 @@ public class LocallyConnected1D extends SameDiffLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java index d12e0ec74..bc746f8ab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; @@ -92,7 +93,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java index df0b16e6c..f1dcd73a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -53,12 +54,13 @@ import java.util.Map; @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) public class RnnLossLayer extends FeedForwardLayer { - + private RNNFormat rnnDataFormat = RNNFormat.NCW; protected ILossFunction lossFn; private RnnLossLayer(Builder builder) { super(builder); this.setLossFn(builder.lossFn); + this.rnnDataFormat = builder.rnnDataFormat; } @Override @@ -91,7 +93,7 @@ public class RnnLossLayer extends FeedForwardLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override @@ -111,8 +113,9 @@ public class RnnLossLayer extends FeedForwardLayer { public static class Builder extends BaseOutputLayer.Builder { - public Builder() { + private RNNFormat rnnDataFormat = RNNFormat.NCW; + public Builder() { } /** @@ -153,6 +156,14 @@ public class RnnLossLayer extends FeedForwardLayer { "This layer has no parameters, thus nIn will always equal nOut."); } + /** + * @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength], + * NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW. + */ + public Builder dataFormat(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + return this; + } @Override @SuppressWarnings("unchecked") public RnnLossLayer build() { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java index 078673f5d..ec0ecf59c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -51,9 +52,11 @@ import java.util.Map; @EqualsAndHashCode(callSuper = true) public class RnnOutputLayer extends BaseOutputLayer { + private RNNFormat rnnDataFormat = RNNFormat.NCW; private RnnOutputLayer(Builder builder) { super(builder); initializeConstraints(builder); + this.rnnDataFormat = builder.rnnDataFormat; } @Override @@ -85,7 +88,7 @@ public class RnnOutputLayer extends BaseOutputLayer { } InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType; - return InputType.recurrent(nOut, itr.getTimeSeriesLength()); + return InputType.recurrent(nOut, itr.getTimeSeriesLength(), itr.getFormat()); } @Override @@ -97,18 +100,20 @@ public class RnnOutputLayer extends BaseOutputLayer { if (nIn <= 0 || override) { InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; + this.rnnDataFormat = r.getFormat(); this.nIn = r.getSize(); } } @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat, getLayerName()); } public static class Builder extends BaseOutputLayer.Builder { + private RNNFormat rnnDataFormat = RNNFormat.NCW; public Builder() { //Set default activation function to softmax (to match default loss function MCXENT) this.setActivationFn(new ActivationSoftmax()); @@ -137,5 +142,14 @@ public class RnnOutputLayer extends BaseOutputLayer { public RnnOutputLayer build() { return new RnnOutputLayer(this); } + + /** + * @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength], + * NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW. + */ + public Builder dataFormat(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + return this; + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java index db898dbdd..8daa4a2c2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; @@ -75,7 +76,7 @@ public class SelfAttentionLayer extends SameDiffLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index 9f3162374..4b2b959c2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -22,6 +22,7 @@ import lombok.NoArgsConstructor; import lombok.ToString; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.Convolution1DUtils; @@ -105,7 +106,7 @@ public class Subsampling1DLayer extends SubsamplingLayer { + "\"): input is null"); } - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java index e888a2904..a3345fde9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java @@ -20,6 +20,7 @@ import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -104,7 +105,7 @@ public class ZeroPadding1DLayer extends NoParamLayer { + "\"): input is null"); } - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java index 4d32f22e5..792e5633b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; @@ -30,6 +31,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer; import org.deeplearning4j.nn.params.BidirectionalParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; @@ -124,6 +126,10 @@ public class Bidirectional extends Layer { } } + public RNNFormat getRNNDataFormat(){ + return TimeSeriesUtils.getFormatFromRnnLayer(fwd); + } + @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, @@ -170,7 +176,7 @@ public class Bidirectional extends Layer { } else { InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) outOrig; if (mode == Mode.CONCAT) { - return InputType.recurrent(2 * r.getSize()); + return InputType.recurrent(2 * r.getSize(), getRNNDataFormat()); } else { return r; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java index bd9685ef9..5489ccc78 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java @@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode; import lombok.NonNull; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; @@ -29,17 +30,19 @@ import java.util.Collection; @EqualsAndHashCode(callSuper = true) public class TimeDistributed extends BaseWrapperLayer { - private final int timeAxis; + private RNNFormat rnnDataFormat = RNNFormat.NCW; /** * @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayer - * @param timeAxis Time axis, should be 2 for DL4J RNN activations (shape [minibatch, size, sequenceLength]) */ - public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("timeAxis") int timeAxis) { + public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) { super(underlying); - this.timeAxis = timeAxis; + this.rnnDataFormat = rnnDataFormat; } + public TimeDistributed(Layer underlying){ + super(underlying); + } @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, @@ -47,7 +50,7 @@ public class TimeDistributed extends BaseWrapperLayer { NeuralNetConfiguration conf2 = conf.clone(); conf2.setLayer(((TimeDistributed) conf2.getLayer()).getUnderlying()); return new TimeDistributedLayer(underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, - initializeParams, networkDataType), timeAxis); + initializeParams, networkDataType), rnnDataFormat); } @Override @@ -59,7 +62,7 @@ public class TimeDistributed extends BaseWrapperLayer { InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType; InputType ff = InputType.feedForward(rnn.getSize()); InputType ffOut = underlying.getOutputType(layerIndex, ff); - return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength()); + return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength(), rnnDataFormat); } @Override @@ -70,6 +73,7 @@ public class TimeDistributed extends BaseWrapperLayer { InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType; InputType ff = InputType.feedForward(rnn.getSize()); + this.rnnDataFormat = rnn.getFormat(); underlying.setNIn(ff, override); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java index 42dca9105..43b1b1e7c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.preprocessor; import lombok.*; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.base.Preconditions; @@ -38,7 +39,7 @@ import java.util.Arrays; * Functionally equivalent to combining CnnToFeedForwardPreProcessor + FeedForwardToRnnPreProcessor
* Specifically, this does two things:
* (a) Reshape 4d activations out of CNN layer, with shape [timeSeriesLength*miniBatchSize, numChannels, inputHeight, inputWidth]) - * into 3d (time series) activations (with shape [numExamples, inputHeight*inputWidth*numChannels, timeSeriesLength]) + * into 3d (time series) activations (with shape [miniBatchSize, inputHeight*inputWidth*numChannels, timeSeriesLength]) * for use in RNN layers
* (b) Reshapes 3d epsilons (weights.*deltas) out of RNN layer (with shape * [miniBatchSize,inputHeight*inputWidth*numChannels,timeSeriesLength]) into 4d epsilons with shape @@ -52,6 +53,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { private long inputHeight; private long inputWidth; private long numChannels; + private RNNFormat rnnDataFormat = RNNFormat.NCW; @Getter(AccessLevel.NONE) @Setter(AccessLevel.NONE) @@ -59,11 +61,20 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { @JsonCreator public CnnToRnnPreProcessor(@JsonProperty("inputHeight") long inputHeight, - @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) { + @JsonProperty("inputWidth") long inputWidth, + @JsonProperty("numChannels") long numChannels, + @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) { this.inputHeight = inputHeight; this.inputWidth = inputWidth; this.numChannels = numChannels; this.product = inputHeight * inputWidth * numChannels; + this.rnnDataFormat = rnnDataFormat; + } + + public CnnToRnnPreProcessor(long inputHeight, + long inputWidth, + long numChannels){ + this(inputHeight, inputWidth, numChannels, RNNFormat.NCW); } @Override @@ -90,14 +101,19 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { //Second: reshape 2d to 3d, as per FeedForwardToRnnPreProcessor INDArray reshaped = workspaceMgr.dup(ArrayType.ACTIVATIONS, twod, 'f'); reshaped = reshaped.reshape('f', miniBatchSize, shape[0] / miniBatchSize, product); - return reshaped.permute(0, 2, 1); + if (rnnDataFormat == RNNFormat.NCW) { + return reshaped.permute(0, 2, 1); + } + return reshaped; } @Override public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { if (output.ordering() == 'c' || !Shape.hasDefaultStridesForShape(output)) output = output.dup('f'); - + if (rnnDataFormat == RNNFormat.NWC){ + output = output.permute(0, 2, 1); + } val shape = output.shape(); INDArray output2d; if (shape[0] == 1) { @@ -122,7 +138,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { @Override public CnnToRnnPreProcessor clone() { - return new CnnToRnnPreProcessor(inputHeight, inputWidth, numChannels); + return new CnnToRnnPreProcessor(inputHeight, inputWidth, numChannels, rnnDataFormat); } @Override @@ -133,7 +149,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; val outSize = c.getChannels() * c.getHeight() * c.getWidth(); - return InputType.recurrent(outSize); + return InputType.recurrent(outSize, rnnDataFormat); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java index aa45b30a5..7da79b935 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java @@ -21,6 +21,7 @@ import lombok.NoArgsConstructor; import lombok.val; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.linalg.api.ndarray.INDArray; @@ -28,7 +29,7 @@ import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.ArrayType; - +import org.nd4j.shade.jackson.annotation.JsonProperty; import java.util.Arrays; /** @@ -48,7 +49,11 @@ import java.util.Arrays; @Data @NoArgsConstructor public class FeedForwardToRnnPreProcessor implements InputPreProcessor { + private RNNFormat rnnDataFormat = RNNFormat.NCW; + public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { //Need to reshape FF activations (2d) activations to 3d (for input into RNN layer) @@ -60,7 +65,10 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor { val shape = input.shape(); INDArray reshaped = input.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]); - return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reshaped.permute(0, 2, 1)); + if (rnnDataFormat == RNNFormat.NCW){ + reshaped = reshaped.permute(0, 2, 1); + } + return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reshaped); } @Override @@ -71,6 +79,9 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor { "Invalid input: expect NDArray with rank 3 (i.e., epsilons from RNN layer)"); if (output.ordering() != 'f' || !Shape.hasDefaultStridesForShape(output)) output = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, output, 'f'); + if (rnnDataFormat == RNNFormat.NWC){ + output = output.permute(0, 2, 1); + } val shape = output.shape(); INDArray ret; @@ -87,12 +98,7 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor { @Override public FeedForwardToRnnPreProcessor clone() { - try { - FeedForwardToRnnPreProcessor clone = (FeedForwardToRnnPreProcessor) super.clone(); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } + return new FeedForwardToRnnPreProcessor(rnnDataFormat); } @Override @@ -104,10 +110,10 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor { if (inputType.getType() == InputType.Type.FF) { InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward) inputType; - return InputType.recurrent(ff.getSize()); + return InputType.recurrent(ff.getSize(), rnnDataFormat); } else { InputType.InputTypeConvolutionalFlat cf = (InputType.InputTypeConvolutionalFlat) inputType; - return InputType.recurrent(cf.getFlattenedSize()); + return InputType.recurrent(cf.getFlattenedSize(), rnnDataFormat); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java index a3920b061..bcfc92170 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java @@ -19,8 +19,10 @@ package org.deeplearning4j.nn.conf.preprocessor; import lombok.*; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.util.TimeSeriesUtils; +import org.nd4j.enums.RnnDataFormat; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.primitives.Pair; @@ -52,19 +54,27 @@ public class RnnToCnnPreProcessor implements InputPreProcessor { private int inputHeight; private int inputWidth; private int numChannels; - + private RNNFormat rnnDataFormat = RNNFormat.NCW; @Getter(AccessLevel.NONE) @Setter(AccessLevel.NONE) private int product; public RnnToCnnPreProcessor(@JsonProperty("inputHeight") int inputHeight, - @JsonProperty("inputWidth") int inputWidth, @JsonProperty("numChannels") int numChannels) { + @JsonProperty("inputWidth") int inputWidth, + @JsonProperty("numChannels") int numChannels, + @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) { this.inputHeight = inputHeight; this.inputWidth = inputWidth; this.numChannels = numChannels; this.product = inputHeight * inputWidth * numChannels; + this.rnnDataFormat = rnnDataFormat; } + public RnnToCnnPreProcessor(int inputHeight, + int inputWidth, + int numChannels){ + this(inputHeight, inputWidth, numChannels, RNNFormat.NCW); + } @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { @@ -72,6 +82,9 @@ public class RnnToCnnPreProcessor implements InputPreProcessor { input = input.dup('f'); //Input: 3d activations (RNN) //Output: 4d activations (CNN) + if (rnnDataFormat == RNNFormat.NWC){ + input = input.permute(0, 2, 1); + } val shape = input.shape(); INDArray in2d; if (shape[0] == 1) { @@ -98,14 +111,17 @@ public class RnnToCnnPreProcessor implements InputPreProcessor { val shape = output.shape(); //First: reshape 4d to 2d INDArray twod = output.reshape('c', output.size(0), ArrayUtil.prod(output.shape()) / output.size(0)); - //Second: reshape 2d to 4d + //Second: reshape 2d to 3d INDArray reshaped = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, twod, 'f').reshape('f', miniBatchSize, shape[0] / miniBatchSize, product); - return reshaped.permute(0, 2, 1); + if (rnnDataFormat == RNNFormat.NCW) { + reshaped = reshaped.permute(0, 2, 1); + } + return reshaped; } @Override public RnnToCnnPreProcessor clone() { - return new RnnToCnnPreProcessor(inputHeight, inputWidth, numChannels); + return new RnnToCnnPreProcessor(inputHeight, inputWidth, numChannels, rnnDataFormat); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java index 7c92a7eaf..125aaf78b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java @@ -16,11 +16,14 @@ package org.deeplearning4j.nn.conf.preprocessor; +import lombok.AllArgsConstructor; import lombok.Data; +import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.linalg.api.ndarray.INDArray; @@ -28,6 +31,7 @@ import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.ArrayType; +import org.nd4j.shade.jackson.annotation.JsonProperty; import java.util.Arrays; @@ -47,8 +51,14 @@ import java.util.Arrays; */ @Data @Slf4j +@NoArgsConstructor public class RnnToFeedForwardPreProcessor implements InputPreProcessor { + private RNNFormat rnnDataFormat = RNNFormat.NCW; + + public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { //Need to reshape RNN activations (3d) activations to 2d (for input into feed forward layer) @@ -59,10 +69,13 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor { if (input.ordering() != 'f' || !Shape.hasDefaultStridesForShape(input)) input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'f'); + if (rnnDataFormat == RNNFormat.NWC){ + input = input.permute(0, 2, 1); + } val shape = input.shape(); INDArray ret; if (shape[0] == 1) { - ret = input.tensorAlongDimension(0, 1, 2).permutei(1, 0); //Edge case: miniBatchSize==1 + ret = input.tensorAlongDimension(0, 1, 2).permute(1, 0); //Edge case: miniBatchSize==1 } else if (shape[2] == 1) { ret = input.tensorAlongDimension(0, 1, 0); //Edge case: timeSeriesLength=1 } else { @@ -85,17 +98,15 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor { val shape = output.shape(); INDArray reshaped = output.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]); - return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, reshaped.permute(0, 2, 1)); + if (rnnDataFormat == RNNFormat.NCW){ + reshaped = reshaped.permute(0, 2, 1); + } + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, reshaped); } @Override public RnnToFeedForwardPreProcessor clone() { - try { - RnnToFeedForwardPreProcessor clone = (RnnToFeedForwardPreProcessor) super.clone(); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } + return new RnnToFeedForwardPreProcessor(rnnDataFormat); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java index 265946bd8..5aa5bc88c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java @@ -18,7 +18,10 @@ package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.nn.api.layers.RecurrentLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.layers.BaseLayer; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -26,7 +29,7 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -public abstract class BaseRecurrentLayer +public abstract class BaseRecurrentLayer extends BaseLayer implements RecurrentLayer { /** @@ -85,4 +88,19 @@ public abstract class BaseRecurrentLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { INDArray eFwd; INDArray eBwd; - + boolean permute = getRNNDataFormat() == RNNFormat.NWC && epsilon.rank() == 3; + if (permute){ + epsilon = epsilon.permute(0, 2, 1); + } val n = epsilon.size(1)/2; switch (layerConf.getMode()){ case ADD: @@ -165,6 +172,10 @@ public class BidirectionalLayer implements RecurrentLayer { eBwd = TimeSeriesUtils.reverseTimeSeries(eBwd, workspaceMgr, ArrayType.BP_WORKING_MEM); + if (permute){ + eFwd = eFwd.permute(0, 2, 1); + eBwd = eBwd.permute(0, 2, 1); + } Pair g1 = fwd.backpropGradient(eFwd, workspaceMgr); Pair g2 = bwd.backpropGradient(eBwd, workspaceMgr); @@ -176,7 +187,9 @@ public class BidirectionalLayer implements RecurrentLayer { g.gradientForVariable().put(BidirectionalParamInitializer.BACKWARD_PREFIX + e.getKey(), e.getValue()); } - INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2.getRight(), workspaceMgr, ArrayType.BP_WORKING_MEM); + INDArray g2Right = permute ? g2.getRight().permute(0, 2, 1): g2.getRight(); + INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2Right, workspaceMgr, ArrayType.BP_WORKING_MEM); + g2Reversed = permute? g2Reversed.permute(0, 2, 1): g2Reversed; INDArray epsOut = g1.getRight().addi(g2Reversed); return new Pair<>(g, epsOut); @@ -186,25 +199,38 @@ public class BidirectionalLayer implements RecurrentLayer { public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { INDArray out1 = fwd.activate(training, workspaceMgr); INDArray out2 = bwd.activate(training, workspaceMgr); + boolean permute = getRNNDataFormat() == RNNFormat.NWC && out1.rank() == 3; + if(permute){ + out1 = out1.permute(0, 2, 1); + out2 = out2.permute(0, 2, 1); + } //Reverse the output time series. Note: when using LastTimeStepLayer, output can be rank 2 out2 = out2.rank() == 2 ? out2 : TimeSeriesUtils.reverseTimeSeries(out2, workspaceMgr, ArrayType.FF_WORKING_MEM); - + INDArray ret; switch (layerConf.getMode()){ case ADD: - return out1.addi(out2); + ret = out1.addi(out2); + break; case MUL: //TODO may be more efficient ways than this... this.outFwd = out1.detach(); this.outBwd = out2.detach(); - return workspaceMgr.dup(ArrayType.ACTIVATIONS, out1).muli(out2); + ret = workspaceMgr.dup(ArrayType.ACTIVATIONS, out1).muli(out2); + break; case AVERAGE: - return out1.addi(out2).muli(0.5); + ret = out1.addi(out2).muli(0.5); + break; case CONCAT: - INDArray ret = Nd4j.concat(1, out1, out2); - return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret); + ret = Nd4j.concat(1, out1, out2); + ret = workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret); + break; default: throw new RuntimeException("Unknown mode: " + layerConf.getMode()); } + if (permute){ + ret = ret.permute(0, 2, 1); + } + return ret; } @Override @@ -465,7 +491,9 @@ public class BidirectionalLayer implements RecurrentLayer { public void setInput(INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) { this.input = input; fwd.setInput(input, layerWorkspaceMgr); - + if (getRNNDataFormat() == RNNFormat.NWC){ + input = input.permute(0, 2, 1); + } INDArray reversed; if(!input.isAttached()){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { @@ -478,6 +506,9 @@ public class BidirectionalLayer implements RecurrentLayer { reversed = TimeSeriesUtils.reverseTimeSeries(input); } } + if (getRNNDataFormat() == RNNFormat.NWC){ + reversed = reversed.permute(0, 2, 1); + } bwd.setInput(reversed, layerWorkspaceMgr); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java index e0fd80842..dc155bff3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java @@ -88,12 +88,12 @@ public class GravesBidirectionalLSTM } final FwdPassReturn fwdPass = activateHelperDirectional(true, null, null, true, true, workspaceMgr); - + fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput); final Pair forwardsGradient = LSTMHelpers.backpropGradientHelper(this, this.conf, - this.layerConf().getGateActivationFn(), this.input, + this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), - getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), epsilon, + getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS, @@ -106,16 +106,17 @@ public class GravesBidirectionalLSTM final Pair backwardsGradient = LSTMHelpers.backpropGradientHelper(this, this.conf, - this.layerConf().getGateActivationFn(), this.input, + this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS), - getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), epsilon, + getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, backPass, false, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS, GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, gradientViews, maskArray, true, null, workspaceMgr, layerConf().isHelperAllowFallback()); - + forwardsGradient.setSecond(permuteIfNWC(forwardsGradient.getSecond())); + backwardsGradient.setSecond(permuteIfNWC(backwardsGradient.getSecond())); //merge the gradient, which is key value pair of String,INDArray //the keys for forwards and backwards should be different @@ -171,7 +172,7 @@ public class GravesBidirectionalLSTM } else { forwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), - this.input, getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), + permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), training, null, null, forBackprop || (cacheMode != CacheMode.NONE && training), true, @@ -179,7 +180,7 @@ public class GravesBidirectionalLSTM forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); backwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), - this.input, + permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS), training, null, null, @@ -187,6 +188,8 @@ public class GravesBidirectionalLSTM GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, maskArray, true, null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); + forwardsEval.fwdPassOutput = permuteIfNWC(forwardsEval.fwdPassOutput); + backwardsEval.fwdPassOutput = permuteIfNWC(backwardsEval.fwdPassOutput); cachedPassForward = forwardsEval; cachedPassBackward = backwardsEval; } @@ -228,10 +231,12 @@ public class GravesBidirectionalLSTM biasKey = GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS; } - return LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), this.input, + FwdPassReturn ret = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(recurrentKey), getParam(inputKey), getParam(biasKey), training, prevOutputActivations, prevMemCellState, forBackprop, forwards, inputKey, maskArray, true, null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); + ret.fwdPassOutput = permuteIfNWC(ret.fwdPassOutput); + return ret; } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java index b112672f9..551d4ff67 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.nd4j.base.Preconditions; @@ -89,17 +90,17 @@ public class GravesLSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this, - this.conf, this.layerConf().getGateActivationFn(), this.input, - recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true, + this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, GravesLSTMParamInitializer.BIAS_KEY, gradientViews, maskArray, true, null, workspaceMgr, layerConf().isHelperAllowFallback()); weightNoiseParams.clear(); - p.setSecond(backpropDropOutIfPresent(p.getSecond())); + p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond()))); return p; } @@ -117,8 +118,8 @@ public class GravesLSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this, - this.conf, this.layerConf().getGateActivationFn(), this.input, - recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true, + this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, LSTMParamInitializer.INPUT_WEIGHT_KEY, LSTMParamInitializer.RECURRENT_WEIGHT_KEY, LSTMParamInitializer.BIAS_KEY, gradientViews, null, false, helper, workspaceMgr, layerConf().isHelperAllowFallback()); weightNoiseParams.clear(); - p.setSecond(backpropDropOutIfPresent(p.getSecond())); + p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond()))); return p; } @@ -167,17 +166,18 @@ public class LSTM extends BaseRecurrentLayer need to zero out these errors to avoid using errors from a masked time step // to calculate the parameter gradients. Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step) timeStepMaskColumn = maskArray.getColumn(time, true); - deltaifogNext.muliColumnVector(timeStepMaskColumn); + deltaifogNext.muli(timeStepMaskColumn); //Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients } @@ -737,7 +736,7 @@ public class LSTMHelpers { if (maskArray != null) { //Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything // but 0s to the layer below at this time step (for the given example) - epsilonNextSlice.muliColumnVector(timeStepMaskColumn); + epsilonNextSlice.muli(timeStepMaskColumn); } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java index 8e9f7c8f1..de85ea7b6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.recurrent; import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.util.TimeSeriesUtils; @@ -59,18 +60,41 @@ public class LastTimeStepLayer extends BaseWrapperLayer { @Override public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { - INDArray newEps = Nd4j.create(epsilon.dataType(), origOutputShape, 'f'); + long[] newEpsShape = origOutputShape; + boolean nwc = (underlying instanceof BaseRecurrentLayer && + ((BaseRecurrentLayer) underlying).getDataFormat() == RNNFormat.NWC)|| + (underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof + BaseRecurrentLayer && ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat() + == RNNFormat.NWC); + INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f'); if(lastTimeStepIdxs == null){ //no mask case - newEps.put(new INDArrayIndex[]{all(), all(), point(origOutputShape[2]-1)}, epsilon); - } else { - INDArrayIndex[] arr = new INDArrayIndex[]{null, all(), null}; - //TODO probably possible to optimize this with reshape + scatter ops... - for( int i=0; i p = TimeSeriesUtils.pullLastTimeSteps(in, mask, workspaceMgr, arrayType); lastTimeStepIdxs = p.getSecond(); + return p.getFirst(); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java index 4e01ea084..a4f53fa7f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java @@ -30,6 +30,9 @@ import org.nd4j.linalg.primitives.Pair; import lombok.NonNull; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import static org.deeplearning4j.nn.conf.RNNFormat.NCW; +import static org.deeplearning4j.nn.conf.RNNFormat.NWC; + /** * Masks timesteps with activation equal to the specified masking value, defaulting to 0.0. * Assumes that the input shape is [batch_size, input_size, timesteps]. @@ -76,7 +79,11 @@ public class MaskZeroLayer extends BaseWrapperLayer { throw new IllegalArgumentException("Expected input of shape [batch_size, timestep_input_size, timestep], " + "got shape "+Arrays.toString(input.shape()) + " instead"); } - INDArray mask = input.eq(maskingValue).castTo(input.dataType()).sum(1).neq(input.shape()[1]); + if ((underlying instanceof BaseRecurrentLayer && + ((BaseRecurrentLayer)underlying).getDataFormat() == NWC)){ + input = input.permute(0, 2, 1); + } + INDArray mask = input.eq(maskingValue).castTo(input.dataType()).sum(1).neq(input.shape()[1]).castTo(input.dataType()); underlying.setMaskArray(mask.detach()); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java index dd1b03d63..4d06ae755 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; @@ -60,6 +61,8 @@ public class RnnLossLayer extends BaseLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); + INDArray input = this.input; + INDArray labels = this.labels; if (input.rank() != 3) throw new UnsupportedOperationException( "Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " + @@ -67,6 +70,10 @@ public class RnnLossLayer extends BaseLayer(gradient, delta3d); @@ -159,13 +168,21 @@ public class RnnLossLayer extends BaseLayer { public static final String STATE_KEY_PREV_ACTIVATION = "prevAct"; + public SimpleRnn(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); } @@ -92,6 +94,7 @@ public class SimpleRnn extends BaseRecurrentLayer p = activateHelper(null, true, true, workspaceMgr); @@ -125,8 +128,9 @@ public class SimpleRnn extends BaseRecurrentLayer= end; i--){ - INDArray dldaCurrent = epsilon.get(all(), all(), point(i)); + INDArray dldaCurrent = epsilon.get(all(), all(), point(i)).dup(); INDArray aCurrent = p.getFirst().get(all(), all(), point(i)); INDArray zCurrent = p.getSecond().get(all(), all(), point(i)); INDArray nCurrent = (hasLayerNorm() ? p.getThird().get(all(), all(), point(i)) : null); @@ -141,7 +145,7 @@ public class SimpleRnn extends BaseRecurrentLayer(grad, epsOut); } @@ -224,6 +229,7 @@ public class SimpleRnn extends BaseRecurrentLayer(out, outZ, outPreNorm, recPreNorm); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java index 727d19eae..93eb8b9fc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java @@ -2,6 +2,7 @@ package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.workspace.ArrayType; @@ -22,11 +23,11 @@ import org.nd4j.linalg.util.ArrayUtil; */ public class TimeDistributedLayer extends BaseWrapperLayer { - private final int timeAxis; + private RNNFormat rnnDataFormat; - public TimeDistributedLayer(Layer underlying, int timeAxis) { + public TimeDistributedLayer(Layer underlying, RNNFormat rnnDataFormat) { super(underlying); - this.timeAxis = timeAxis; + this.rnnDataFormat = rnnDataFormat; } @@ -56,7 +57,7 @@ public class TimeDistributedLayer extends BaseWrapperLayer { protected INDArray reshape(INDArray array){ //Reshape the time axis to the minibatch axis //For example, for RNN -> FF (dense time distributed): [mb, size, seqLen] -> [mb x seqLen, size] - int axis = timeAxis; + int axis = (rnnDataFormat == RNNFormat.NCW) ? 2 : 1; if(axis < 0) axis += array.rank(); @@ -91,7 +92,7 @@ public class TimeDistributedLayer extends BaseWrapperLayer { protected INDArray revertReshape(INDArray toRevert, long minibatch){ - int axis = timeAxis; + int axis = (rnnDataFormat == RNNFormat.NCW)? 2 : 1; if(axis < 0) axis += (toRevert.rank()+1); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java index 80383698b..2a58387fe 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java @@ -17,6 +17,13 @@ package org.deeplearning4j.util; import lombok.val; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; +import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; +import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; +import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; +import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; @@ -233,6 +240,12 @@ public class TimeSeriesUtils { return outReshape.reshape('f', in.size(0), in.size(1), in.size(2)); } + public static INDArray reverseTimeSeries(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType, RNNFormat dataFormat){ + if (dataFormat == RNNFormat.NCW){ + return reverseTimeSeries(in, workspaceMgr, arrayType); + } + return reverseTimeSeries(in.permute(0, 2, 1), workspaceMgr, arrayType).permute(0, 2, 1); + } /** * Reverse an input time series along the time dimension * @@ -423,4 +436,25 @@ public class TimeSeriesUtils { return new Pair<>(workspaceMgr.leverageTo(arrayType, out), fwdPassTimeSteps); } + + /** + * Get the {@link RNNFormat} from the RNN layer, accounting for the presence of wrapper layers like Bidirectional, + * LastTimeStep, etc + * @param layer Layer to get the RNNFormat from + */ + public static RNNFormat getFormatFromRnnLayer(Layer layer){ + if(layer instanceof BaseRecurrentLayer){ + return ((BaseRecurrentLayer) layer).getRnnDataFormat(); + } else if(layer instanceof MaskZeroLayer){ + return getFormatFromRnnLayer(((MaskZeroLayer) layer).getUnderlying()); + } else if(layer instanceof Bidirectional){ + return getFormatFromRnnLayer(((Bidirectional) layer).getFwd()); + } else if(layer instanceof LastTimeStep){ + return getFormatFromRnnLayer(((LastTimeStep) layer).getUnderlying()); + } else if(layer instanceof TimeDistributed){ + return ((TimeDistributed) layer).getRnnDataFormat(); + } else { + throw new IllegalStateException("Unable to get RNNFormat from layer of type: " + layer); + } + } }