DL4J NWC support for RNNs (#379)

* merge conf

* merge conf

* conf fix

* NWC initial

* revert pom.xml

* revert pom.xml

* default NCW

* bidirectional+some tests

* RNNOutputLayer, RNNLossLayer, Graves + tests

* rnn tests

* LastTimeStep + tests

* masking + tests

* graves, rnnoutput, rnnloss

* nwc timeseries reverse

* more tests

* bi-graveslstm test

* fixes

* rnn df tests basic

* bug fix: cudnn fallback

* bug fix

* misc

* graveslstm tests

* preprocessor fixes

* TimeDistributed

* more tests

* RnnLossLayer builder def val

* copyright headers

* Remove debug println

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small fix + test naming

Signed-off-by: Alex Black <blacka101@gmail.com>

* Parameterized test name

Signed-off-by: Alex Black <blacka101@gmail.com>

* fix LastTimeStep masked

* Fix MaskZero mask datatype issue

Signed-off-by: Alex Black <blacka101@gmail.com>

* rem println

* javadoc

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
Branch: master
Fariz Rahman 2020-04-23 06:16:44 +04:00 committed by GitHub
parent 032b97912e
commit 2ecabde500
48 changed files with 1256 additions and 278 deletions
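For readers skimming the diff, here is the feature in one place: a minimal, hedged sketch (class name NwcExample and the concrete sizes are illustrative, not taken from this commit) of a network configured to consume NWC input, shaped [minibatch, timeSeriesLength, size], instead of the default NCW, shaped [minibatch, size, timeSeriesLength]. It uses only builder options that appear in the diffs below: dataFormat(RNNFormat) on the recurrent and output layer builders, and the format-aware InputType.recurrent(...) overload.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NwcExample {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .list()
                .layer(new LSTM.Builder().nIn(3).nOut(10)
                        .activation(Activation.TANH)
                        .dataFormat(RNNFormat.NWC)      //per-layer format; the default is NCW
                        .build())
                .layer(new RnnOutputLayer.Builder().nOut(5)
                        .activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .dataFormat(RNNFormat.NWC)
                        .build())
                //The format can also be carried by the input type, as the new tests do:
                .setInputType(InputType.recurrent(3, 12, RNNFormat.NWC))
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        INDArray in = Nd4j.rand(DataType.FLOAT, 2, 12, 3);  //[mb, length, size] = NWC
        INDArray out = net.output(in);                      //[2, 12, 5], also NWC
    }
}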

File: RnnGradientChecks.java

@@ -338,7 +338,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(layerSize).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build(), 2))
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build()))
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn))

File: DTypeTests.java

@@ -819,7 +819,7 @@ public class DTypeTests extends BaseDL4JTest {
.layer(new DenseLayer.Builder().nOut(5).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
.layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()))
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2))
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build())
.layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
.layer(secondLast)

File: BidirectionalTest.java

@@ -24,10 +24,7 @@ import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
@@ -45,6 +42,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -61,12 +60,22 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
import static org.junit.Assert.assertEquals;
@Slf4j
@RunWith(Parameterized.class)
public class BidirectionalTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public BidirectionalTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void compareImplementations(){
for(WorkspaceMode wsm : WorkspaceMode.values()) {
@@ -82,9 +91,9 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build())
.build();
@@ -95,9 +104,9 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build())
.build();
@@ -116,15 +125,24 @@ public class BidirectionalTest extends BaseDL4JTest {
net2.setParams(net1.params()); //Assuming exact same layout here...
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
INDArray in;
if (rnnDataFormat == NCW){
in = Nd4j.rand(new int[]{3, 10, 5});
}else{
in = Nd4j.rand(new int[]{3, 5, 10});
}
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
INDArray labels;
if (rnnDataFormat == NCW){
labels = Nd4j.rand(new int[]{3, 10, 5});
}else{
labels = Nd4j.rand(new int[]{3, 5, 10});
}
net1.setInput(in);
net1.setLabels(labels);
@@ -276,17 +294,22 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.nIn(10).nOut(10).build())
.nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
INDArray in;
INDArray labels;
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 5} : new long[]{3, 5, 10};
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
net1.fit(in, labels);
@@ -300,8 +323,8 @@ public class BidirectionalTest extends BaseDL4JTest {
MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true);
in = Nd4j.rand(new int[]{3, 10, 5});
labels = Nd4j.rand(new int[]{3, 10, 5});
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
@@ -338,18 +361,18 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam())
.graphBuilder()
.addInputs("in")
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in")
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0")
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0")
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build(), "1")
.setOutputs("2")
.build();
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
long[] inshape = (rnnDataFormat == NCW)? new long[]{3, 10, 5}: new long[]{3, 5, 10};
INDArray in = Nd4j.rand(inshape);
INDArray labels = Nd4j.rand(inshape);
net1.fit(new DataSet(in, labels));
@@ -363,8 +386,8 @@ public class BidirectionalTest extends BaseDL4JTest {
ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true);
in = Nd4j.rand(new int[]{3, 10, 5});
labels = Nd4j.rand(new int[]{3, 10, 5});
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in);
@@ -394,8 +417,8 @@ public class BidirectionalTest extends BaseDL4JTest {
Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD,
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
INDArray in = Nd4j.rand(inshape);
for (Bidirectional.Mode m : modes) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
@@ -406,7 +429,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
@@ -418,7 +441,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.updater(new Adam())
.list()
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build())
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone());
@@ -434,11 +457,10 @@ public class BidirectionalTest extends BaseDL4JTest {
net3.setParam("0_RW", net1.getParam("0_bRW"));
net3.setParam("0_b", net1.getParam("0_bb"));
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
INDArray outExp;
switch (m) {
@@ -452,7 +474,7 @@ public class BidirectionalTest extends BaseDL4JTest {
outExp = out2.add(out3).muli(0.5);
break;
case CONCAT:
outExp = Nd4j.concat(1, out2, out3);
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
break;
default:
throw new RuntimeException();
@@ -464,25 +486,25 @@ public class BidirectionalTest extends BaseDL4JTest {
//Check gradients:
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
INDArray eps = Nd4j.rand(inshape);
INDArray eps1;
if (m == Bidirectional.Mode.CONCAT) {
eps1 = Nd4j.concat(1, eps, eps);
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
} else {
eps1 = eps;
}
net1.setInput(in);
net2.setInput(in);
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat));
net1.feedForward(true, false);
net2.feedForward(true, false);
net3.feedForward(true, false);
Pair<Gradient, INDArray> p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT), LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat), LayerWorkspaceMgr.noWorkspaces());
Gradient g1 = p1.getFirst();
Gradient g2 = p2.getFirst();
Gradient g3 = p3.getFirst();
@@ -520,7 +542,9 @@ public class BidirectionalTest extends BaseDL4JTest {
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
INDArray in = Nd4j.rand(inshape);
for (Bidirectional.Mode m : modes) {
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
@@ -532,7 +556,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam())
.graphBuilder()
.addInputs("in")
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()), "in")
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
.setOutputs("0")
.build();
@@ -546,7 +570,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam())
.graphBuilder()
.addInputs("in")
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in")
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in")
.setOutputs("0")
.build();
@@ -566,9 +590,20 @@ public class BidirectionalTest extends BaseDL4JTest {
INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.outputSingle(
TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)),
LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
INDArray out3;
INDArray inReverse;
if (rnnDataFormat == RNNFormat.NWC){
inReverse = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
out3 = net3.outputSingle(inReverse);
out3 = TimeSeriesUtils.reverseTimeSeries(out3.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
}
else{
inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
out3 = net3.outputSingle(inReverse);
out3 = TimeSeriesUtils.reverseTimeSeries(out3, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
}
INDArray outExp;
switch (m) {
@@ -582,7 +617,9 @@ public class BidirectionalTest extends BaseDL4JTest {
outExp = out2.add(out3).muli(0.5);
break;
case CONCAT:
outExp = Nd4j.concat(1, out2, out3);
System.out.println(out2.shapeInfoToString());
System.out.println(out3.shapeInfoToString());
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
break;
default:
throw new RuntimeException();
@@ -594,22 +631,26 @@ public class BidirectionalTest extends BaseDL4JTest {
//Check gradients:
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
INDArray eps = Nd4j.rand(inshape);
INDArray eps1;
if (m == Bidirectional.Mode.CONCAT) {
eps1 = Nd4j.concat(1, eps, eps);
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
} else {
eps1 = eps;
}
INDArray epsReversed = (rnnDataFormat == NCW)?
TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT):
TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)
.permute(0, 2, 1);
net1.outputSingle(true, false, in);
net2.outputSingle(true, false, in);
net3.outputSingle(true, false, TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
net3.outputSingle(true, false, inReverse);
Gradient g1 = net1.backpropGradient(eps1);
Gradient g2 = net2.backpropGradient(eps);
Gradient g3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
Gradient g3 = net3.backpropGradient(epsReversed);
for (boolean updates : new boolean[]{false, true}) {
if (updates) {

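An aside on the reversal idiom used in the test above: an NWC series is reversed along time by viewing it as NCW (permute the last two axes), reversing, and permuting back; the new four-argument TimeSeriesUtils.reverseTimeSeries overload introduced by this commit wraps the same logic. A hedged sketch (the helper name reverseNwc is illustrative, and ArrayType is assumed to live in org.deeplearning4j.nn.workspace; its import is not shown in this diff):

import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;

class ReverseNwcSketch {
    //Reverse the time axis of an NWC ([mb, length, size]) time series
    static INDArray reverseNwc(INDArray in) {
        //View as NCW (time on the last axis), reverse, then view back as NWC:
        return TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1),
                LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
        //Equivalent one-liner with the new overload:
        //TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(),
        //        ArrayType.INPUT, RNNFormat.NWC);
    }
}
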
File: GravesBidirectionalLSTMTest.java

@@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@@ -31,6 +32,8 @@ import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -42,10 +45,18 @@ import org.nd4j.linalg.primitives.Pair;
import static org.junit.Assert.*;
@RunWith(Parameterized.class)
public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
private double score = 0.0;
private RNNFormat rnnDataFormat;
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testBidirectionalLSTMGravesForwardBasic() {
//Very basic test of forward prop. of LSTM layer with a time series.
@@ -55,7 +66,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(nHiddenUnits).activation(Activation.TANH).build())
.nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build())
.build();
val numParams = conf.getLayer().initializer().numParams(conf);
@@ -65,7 +76,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
//Data: has shape [miniBatchSize,nIn,timeSeriesLength];
//Output/activations has shape [miniBatchSize,nHiddenUnits,timeSeriesLength];
if (rnnDataFormat == RNNFormat.NCW){
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1});
@@ -82,6 +93,25 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
}
else{
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, 1, nIn);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, 1, nHiddenUnits});
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, 1, nIn);
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations2.shape(), new long[] {10, 1, nHiddenUnits});
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, 12, nIn);
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations3.shape(), new long[] {1, 12, nHiddenUnits});
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, 15, nIn);
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations4.shape(), new long[] {10, 15, nHiddenUnits});
}
}
@Test
public void testBidirectionalLSTMGravesBackwardBasic() {
@@ -94,14 +124,15 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
}
private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
int timeSeriesLength) {
INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength);
INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength):
Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(lstmNHiddenUnits)
.nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
.build();
@@ -114,7 +145,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces());
assertNotNull(lstm.input());
INDArray epsilon = Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength);
INDArray epsilon =(rnnDataFormat == RNNFormat.NCW)? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength):
Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits);
Pair<Gradient, INDArray> out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
Gradient outGradient = out.getFirst();
@@ -147,7 +179,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3});
assertNotNull(nextEpsilon);
if (rnnDataFormat == RNNFormat.NCW) {
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, nIn, timeSeriesLength});
}else{
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, timeSeriesLength, nIn });
}
//Check update:
for (String s : outGradient.gradientForVariable().keySet()) {
@@ -226,7 +262,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(layerSize)
.nOut(layerSize).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build())
.build();
@@ -237,7 +273,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.instantiate(confBidirectional, null, 0, params, true, params.dataType());
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());
@@ -265,13 +302,13 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration confBidirectional =
new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
.nIn(nIn).nOut(layerSize)
.nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(-0.1, 0.1))
.activation(Activation.TANH).updater(new NoOp()).build())
.build();
final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.weightInit(WeightInit.ZERO).activation(Activation.TANH).build())
.build();
@@ -290,9 +327,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards)));
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
final INDArray sigb = sig.dup();
if (rnnDataFormat == RNNFormat.NCW) {
reverseColumnsInPlace(sigb.slice(0));
}
else{
reverseColumnsInPlace(sigb.slice(0).permute(1, 0));
}
final INDArray recurrentWeightsF = bidirectionalLSTM
.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS);
@@ -345,10 +389,14 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f);
final INDArray randSig = Nd4j.rand(new int[] {1, layerSize, timeSeriesLength});
final INDArray randSigBackwards = randSig.dup();
final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}):
Nd4j.rand(new int[] {1, timeSeriesLength, layerSize});
INDArray randSigBackwards = randSig.dup();
if (rnnDataFormat == RNNFormat.NCW){
reverseColumnsInPlace(randSigBackwards.slice(0));
}else{
reverseColumnsInPlace(randSigBackwards.slice(0).permute(1, 0));
}
final Pair<Gradient, INDArray> backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
final Pair<Gradient, INDArray> backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
@@ -399,10 +447,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0);
final INDArray activation3Reverse = activation3.dup();
if (rnnDataFormat == RNNFormat.NCW){
reverseColumnsInPlace(activation3Reverse);
}
else{
reverseColumnsInPlace(activation3Reverse.permute(1, 0));
}
assertEquals(activation3Reverse, activation1);
assertArrayEquals(activation3Reverse.shape(), activation1.shape());
assertEquals(activation3Reverse, activation1);
//test backprop now
final INDArray refBackGradientReccurrent =
@@ -434,7 +488,12 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray refEpsilon = backprop1.getSecond().dup();
final INDArray backEpsilon = backprop3.getSecond().dup();
if (rnnDataFormat == RNNFormat.NCW) {
reverseColumnsInPlace(refEpsilon.slice(0));
}
else{
reverseColumnsInPlace(refEpsilon.slice(0).permute(1, 0));
}
assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6);
}
@@ -477,10 +536,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.seed(12345).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2)
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat)
.build())
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2)
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat)
.activation(Activation.TANH).build())
.build();
@@ -492,7 +551,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
INDArray in = Nd4j.rand(new int[] {3, 2, 5});
INDArray labels = Nd4j.rand(new int[] {3, 2, 5});
if (rnnDataFormat == RNNFormat.NWC){
in = in.permute(0, 2, 1);
labels = labels.permute(0, 2, 1);
}
net.fit(in, labels);
}
}

File: MaskZeroLayerTest.java

@@ -21,11 +21,14 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -36,9 +39,17 @@ import java.util.Collections;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class MaskZeroLayerTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public MaskZeroLayerTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void activate() {
@@ -57,7 +68,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
.activation(Activation.IDENTITY)
.gateActivationFunction(Activation.IDENTITY)
.nIn(2)
.nOut(1)
.nOut(1).dataFormat(rnnDataFormat)
.build();
NeuralNetConfiguration conf = new NeuralNetConfiguration();
conf.setLayer(underlying);
@@ -72,9 +83,14 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue);
INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3});
if (rnnDataFormat == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
//WHEN
INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces());
if (rnnDataFormat == RNNFormat.NWC){
out = out.permute(0, 2,1);
}
//THEN output should only be incremented for the non-zero timesteps
INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
@@ -94,7 +110,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder()
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).build()).build())
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

File: RnnDataFormatTests.java (new file)

@@ -0,0 +1,394 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.layers.recurrent;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
@AllArgsConstructor
public class RnnDataFormatTests extends BaseDL4JTest {
private boolean helpers;
private boolean lastTimeStep;
private boolean maskZeros;
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}")
public static List params(){
List<Object[]> ret = new ArrayList<>();
for (boolean helpers: new boolean[]{true, false})
for (boolean lastTimeStep: new boolean[]{true, false})
for (boolean maskZero: new boolean[]{true, false})
ret.add(new Object[]{helpers, lastTimeStep, maskZero});
return ret;
}
@Test
public void testSimpleRnn() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSimpleRnnNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getSimpleRnnNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getSimpleRnnNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getSimpleRnnNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGraveLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGravesLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getGravesLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getGravesLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getGravesLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGraveBiLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGravesBidirectionalLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getGravesBidirectionalLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getGravesBidirectionalLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getGravesBidirectionalLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
private MultiLayerNetwork getGravesBidirectionalLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new GravesLSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new GravesLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new LSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new LSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getSimpleRnnNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new SimpleRnn.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new SimpleRnn.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getNetWithLayer(Layer layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
if (maskZeros){
layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build();
}
if(lastTimeStep){
layer = new LastTimeStep(layer);
}
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(new LSTM.Builder()
.nIn(3)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build())
.layer(layer)
.layer(
(lastTimeStep)?new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build():
new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).dataFormat(format).build()
)
.setInputType(InputType.recurrent(3, 12, format));
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
@AllArgsConstructor
@Data
@NoArgsConstructor
@Builder
private static class TestCase {
private String msg;
private MultiLayerNetwork net1;
private MultiLayerNetwork net2;
private MultiLayerNetwork net3;
private MultiLayerNetwork net4;
private INDArray inNCW;
private INDArray labelsNCW;
private INDArray labelsNWC;
private int testLayerIdx;
private boolean nwcOutput;
public static void testHelper(TestCase tc) {
tc.net2.params().assign(tc.net1.params());
tc.net3.params().assign(tc.net1.params());
tc.net4.params().assign(tc.net1.params());
INDArray inNCW = tc.inNCW;
INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup();
INDArray l0_1 = tc.net1.feedForward(inNCW).get(tc.testLayerIdx + 1);
INDArray l0_2 = tc.net2.feedForward(inNCW).get(tc.testLayerIdx + 1);
INDArray l0_3 = tc.net3.feedForward(inNWC).get(tc.testLayerIdx + 1);
INDArray l0_4 = tc.net4.feedForward(inNWC).get(tc.testLayerIdx + 1);
boolean rank3Out = tc.labelsNCW.rank() == 3;
assertEquals(tc.msg, l0_1, l0_2);
if (rank3Out){
assertEquals(tc.msg, l0_1, l0_3.permute(0, 2, 1));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 2, 1));
}
else{
assertEquals(tc.msg, l0_1, l0_3);
assertEquals(tc.msg, l0_1, l0_4);
}
INDArray out1 = tc.net1.output(inNCW);
INDArray out2 = tc.net2.output(inNCW);
INDArray out3 = tc.net3.output(inNWC);
INDArray out4 = tc.net4.output(inNWC);
assertEquals(tc.msg, out1, out2);
if (rank3Out){
assertEquals(tc.msg, out1, out3.permute(0, 2, 1)); //NWC to NCW
assertEquals(tc.msg, out1, out4.permute(0, 2, 1));
}
else{
assertEquals(tc.msg, out1, out3); //NWC to NCW
assertEquals(tc.msg, out1, out4);
}
//Test backprop
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCW, tc.labelsNCW, null, null);
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCW, tc.labelsNCW, null, null);
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNWC, tc.labelsNWC, null, null);
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNWC, tc.labelsNWC, null, null);
//Input gradients
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0, 2, 1)); //Input gradients for NWC input are also in NWC format
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0, 2, 1));
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
tc.net1.fit(inNCW, tc.labelsNCW);
tc.net2.fit(inNCW, tc.labelsNCW);
tc.net3.fit(inNWC, tc.labelsNWC);
tc.net4.fit(inNWC, tc.labelsNWC);
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
//Test serialization
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
out1 = tc.net1.output(inNCW);
assertEquals(tc.msg, out1, net1a.output(inNCW));
assertEquals(tc.msg, out1, net2a.output(inNCW));
if (rank3Out) {
assertEquals(tc.msg, out1, net3a.output(inNWC).permute(0, 2, 1)); //NWC to NCW
assertEquals(tc.msg, out1, net4a.output(inNWC).permute(0, 2, 1));
}
else{
assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW
assertEquals(tc.msg, out1, net4a.output(inNWC));
}
}
}
private static List<String> differentGrads(Gradient g1, Gradient g2){
List<String> differs = new ArrayList<>();
Map<String,INDArray> m1 = g1.gradientForVariable();
Map<String,INDArray> m2 = g2.gradientForVariable();
for(String s : m1.keySet()){
INDArray a1 = m1.get(s);
INDArray a2 = m2.get(s);
if(!a1.equals(a2)){
differs.add(s);
}
}
return differs;
}
}
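The invariant that the TestCase harness verifies, stated on its own: NCW and NWC tensors hold the same data modulo a permute of the last two axes, so two identically parameterized nets that differ only in format must produce permute-equal outputs. A hedged fragment (netNCW and netNWC are hypothetical networks built like net1 and net3 above, with shared params; assumes JUnit's static org.junit.Assert.assertEquals import):

static void checkFormatsAgree(MultiLayerNetwork netNCW, MultiLayerNetwork netNWC) {
    INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);  //[mb, size, length]
    INDArray inNWC = inNCW.permute(0, 2, 1).dup();         //same data, [mb, length, size]
    INDArray outNCW = netNCW.output(inNCW);                //e.g. [2, 10, 12] (NCW)
    INDArray outNWC = netNWC.output(inNWC);                //e.g. [2, 12, 10] (NWC)
    assertEquals(outNCW, outNWC.permute(0, 2, 1));         //equal after permuting back
}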

File: TestLastTimeStepLayer.java

@@ -21,6 +21,7 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
@@ -29,6 +30,8 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -42,14 +45,25 @@ import static org.nd4j.linalg.activations.Activation.IDENTITY;
import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
@RunWith(Parameterized.class)
public class TestLastTimeStepLayer extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters(name="{0}")
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testLastTimeStepVertex() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
.nIn(5).nOut(6).build()), "in")
.nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in")
.setOutputs("lastTS")
.build();
@@ -59,9 +73,22 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
//First: test without input mask array
Nd4j.getRandom().setSeed(12345);
Layer l = graph.getLayer("lastTS");
INDArray in = Nd4j.rand(new int[]{3, 5, 6});
INDArray in;
if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(3, 5, 6);
}
else{
in = Nd4j.rand(3, 6, 5);
}
INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces());
INDArray expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
INDArray expOut;
if (rnnDataFormat == RNNFormat.NCW){
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
}
else{
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.point(5), NDArrayIndex.all());
}
//Forward pass:
@@ -76,9 +103,17 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
graph.setLayerMaskArrays(new INDArray[]{inMask}, null);
expOut = Nd4j.zeros(3, 6);
if (rnnDataFormat == RNNFormat.NCW){
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));
}
else{
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all()));
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.point(3), NDArrayIndex.all()));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.point(4), NDArrayIndex.all()));
}
outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces());
assertEquals(expOut, outFwd);
@@ -97,9 +132,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
.seed(1234)
.graphBuilder()
.addInputs("in")
.setInputTypes(InputType.recurrent(1))
.setInputTypes(InputType.recurrent(1, rnnDataFormat))
.addLayer("RNN", new LastTimeStep(new LSTM.Builder()
.nOut(10)
.nOut(10).dataFormat(rnnDataFormat)
.build()), "in")
.addLayer("dense", new DenseLayer.Builder()
.nOut(10)
@@ -120,7 +155,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
INDArray fm2 = Nd4j.zeros(1,24);
INDArray fm3 = Nd4j.zeros(1,24);
fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1);
if (rnnDataFormat == RNNFormat.NWC){
f = f.permute(0, 2, 1);
}
INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1});
try {
cg.output(false, new INDArray[]{f}, new INDArray[]{fm2});

File: TestRnnLayers.java

@@ -20,6 +20,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.dropout.TestDropout;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
@@ -31,6 +32,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -46,8 +49,18 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
@RunWith(Parameterized.class)
public class TestRnnLayers extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestRnnLayers(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testTimeStepIs3Dimensional() {
@@ -58,8 +71,8 @@ public class TestRnnLayers extends BaseDL4JTest {
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).build())
.layer(new LSTM.Builder().nIn(3).nOut(5).build())
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new LSTM.Builder().nIn(3).nOut(5).dataFormat(rnnDataFormat).build())
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build())
.build();
@@ -70,9 +83,9 @@ public class TestRnnLayers extends BaseDL4JTest {
org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn =
(org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0);
INDArray rnnInput3d = Nd4j.create(10, 12, 1);
INDArray rnnInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10,12, 1):Nd4j.create(10, 1, 12);
INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces());
assertTrue(Arrays.equals(simpleOut.shape(), new long[] {10, 3, 1}));
assertTrue(Arrays.equals(simpleOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 3, 1}:new long[]{10, 1, 3}));
INDArray rnnInput2d = Nd4j.create(10, 12);
try {
@@ -84,9 +97,9 @@ public class TestRnnLayers extends BaseDL4JTest {
org.deeplearning4j.nn.layers.recurrent.LSTM lstm =
(org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1);
INDArray lstmInput3d = Nd4j.create(10, 3, 1);
INDArray lstmInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10, 3, 1):Nd4j.create(10, 1, 3);
INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces());
assertTrue(Arrays.equals(lstmOut.shape(), new long[] {10, 5, 1}));
assertTrue(Arrays.equals(lstmOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 5, 1}:new long[]{10, 1, 5}));
INDArray lstmInput2d = Nd4j.create(10, 3);
try {
@@ -112,19 +125,19 @@ public class TestRnnLayers extends BaseDL4JTest {
TestDropout.CustomDropout cd = new TestDropout.CustomDropout();
switch (s){
case "graves":
layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break;
case "lstm":
layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break;
case "simple":
layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break;
default:
throw new RuntimeException(s);
@@ -134,21 +147,21 @@ public class TestRnnLayers extends BaseDL4JTest {
.seed(12345)
.list()
.layer(layer)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(layerD)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(layerD2)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@@ -178,7 +191,6 @@ public class TestRnnLayers extends BaseDL4JTest {
assertNotEquals(s, out2, out2D);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);
net.fit(f.dup(), l);
netD.fit(f.dup(), l);
assertNotEquals(s, net.params(), netD.params());
@@ -209,10 +221,10 @@ public class TestRnnLayers extends BaseDL4JTest {
switch (i){
case 0:
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build());
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
break;
case 1:
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build());
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).dataFormat(rnnDataFormat).build());
break;
default:
throw new RuntimeException();
@@ -224,13 +236,16 @@ public class TestRnnLayers extends BaseDL4JTest {
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
if (rnnDataFormat == RNNFormat.NWC){
l = l.permute(0, 2, 1);
}
try{
net.fit(in,l);
} catch (Throwable t){
String msg = t.getMessage();
if(msg == null)
t.printStackTrace();
System.out.println(i);
assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
}

File: TestSimpleRnn.java

@@ -20,10 +20,13 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -36,8 +39,18 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
@RunWith(Parameterized.class)
public class TestSimpleRnn extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestSimpleRnn(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testSimpleRnn(){
Nd4j.getRandom().setSeed(12345);
@ -46,7 +59,15 @@ public class TestSimpleRnn extends BaseDL4JTest {
int nIn = 5;
int layerSize = 6;
int tsLength = 7;
INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
INDArray in;
if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
}
else{
in = Nd4j.rand(DataType.FLOAT, new int[]{m, tsLength, nIn});
}
// in.get(all(), all(), interval(1,tsLength)).assign(0);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
@ -54,7 +75,7 @@ public class TestSimpleRnn extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.activation(Activation.TANH)
.list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).build())
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -68,7 +89,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
INDArray outLast = null;
for( int i=0; i<tsLength; i++ ){
INDArray inCurrent = in.get(all(), all(), point(i));
INDArray inCurrent;
if (rnnDataFormat == RNNFormat.NCW){
inCurrent = in.get(all(), all(), point(i));
}
else{
inCurrent = in.get(all(), point(i), all());
}
INDArray outExpCurrent = inCurrent.mmul(w);
if(outLast != null){
@ -79,7 +106,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
Transforms.tanh(outExpCurrent, false);
INDArray outActCurrent = out.get(all(), all(), point(i));
INDArray outActCurrent;
if (rnnDataFormat == RNNFormat.NCW){
outActCurrent = out.get(all(), all(), point(i));
}
else{
outActCurrent = out.get(all(), point(i), all());
}
assertEquals(String.valueOf(i), outExpCurrent, outActCurrent);
outLast = outExpCurrent;
@ -100,7 +133,7 @@ public class TestSimpleRnn extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.activation(Activation.TANH)
.list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize)
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.biasInit(100)
.build())
.build();


@ -4,6 +4,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
@ -12,6 +13,8 @@ import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -22,8 +25,18 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class TestTimeDistributed extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestTimeDistributed(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testTimeDistributed(){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
@ -34,11 +47,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
.seed(12345)
.updater(new Adam(0.1))
.list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
.layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build())
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3))
.setInputType(InputType.recurrent(3, rnnDataFormat))
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
@ -47,11 +60,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
.seed(12345)
.updater(new Adam(0.1))
.list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), 2))
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
.layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), rnnDataFormat))
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3))
.setInputType(InputType.recurrent(3, rnnDataFormat))
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
@ -62,13 +75,21 @@ public class TestTimeDistributed extends BaseDL4JTest {
for( int mb : new int[]{1, 5}) {
for(char inLabelOrder : new char[]{'c', 'f'}) {
INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder);
if (rnnDataFormat == RNNFormat.NWC){
in = in.permute(0, 2, 1);
}
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
INDArray labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
INDArray labels;
if (rnnDataFormat == RNNFormat.NCW) {
labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
}else{
labels = TestUtils.randomOneHotTimeSeries(mb, 5, 3).dup(inLabelOrder);
}
DataSet ds = new DataSet(in, labels);
net1.fit(ds);

View File

@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
@ -160,8 +161,8 @@ public class KerasConvolution1D extends KerasConvolution {
public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras LSTM layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
"Keras Conv1D layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], RNNFormat.NCW,layerName);
}


@ -22,11 +22,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
@ -37,6 +35,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -266,7 +265,8 @@ public class KerasLSTM extends KerasLayer {
throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one input " +
"or three (input to LSTM and two state tensors), but " +
"received " + inputType.length + ".");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f,layerName);
}
/**


@ -21,7 +21,9 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer;
@ -36,6 +38,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
@ -227,7 +230,8 @@ public class KerasSimpleRnn extends KerasLayer {
throw new InvalidKerasConfigurationException(
"Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName);
}
/**


@ -218,7 +218,7 @@ public class KerasBidirectional extends KerasLayer {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], ((Bidirectional)layer).getRNNDataFormat(), layerName);
}
/**


@ -17,6 +17,7 @@
package org.deeplearning4j.nn.api.layers;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
@ -98,4 +99,5 @@ public interface RecurrentLayer extends Layer {
*/
Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr);
}


@ -0,0 +1,29 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf;
/**
* NCW = "channels first" - arrays of shape [minibatch, channels, width]<br>
* NWC = "channels last" - arrays of shape [minibatch, width, channels]<br>
* "width" corresponds to sequence length and "channels" corresponds to sequence item size.
*/
public enum RNNFormat {
NCW,
NWC
}
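A quick illustration of the two layouts (not part of the diff; a minimal ND4J sketch using the same Nd4j.rand style as the tests above):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RnnFormatShapes {
    public static void main(String[] args) {
        int minibatch = 3, channels = 5, width = 7;   // width = sequence length
        // NCW ("channels first"): [minibatch, channels, width]
        INDArray ncw = Nd4j.rand(DataType.FLOAT, minibatch, channels, width);
        // NWC ("channels last"): [minibatch, width, channels] - a permuted view of the same data
        INDArray nwc = ncw.permute(0, 2, 1);
        System.out.println(java.util.Arrays.toString(ncw.shape()));   // [3, 5, 7]
        System.out.println(java.util.Arrays.toString(nwc.shape()));   // [3, 7, 5]
    }
}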


@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -111,9 +112,16 @@ public abstract class InputType implements Serializable {
* @return InputTypeRecurrent
*/
public static InputType recurrent(long size, long timeSeriesLength) {
return new InputTypeRecurrent(size, timeSeriesLength);
return new InputTypeRecurrent(size, timeSeriesLength, RNNFormat.NCW);
}
public static InputType recurrent(long size, RNNFormat format){
return new InputTypeRecurrent(size, format);
}
public static InputType recurrent(long size, long timeSeriesLength, RNNFormat format){
return new InputTypeRecurrent(size, timeSeriesLength, format);
}
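A sketch of how the three factory overloads above would be called (illustrative sizes only):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;

public class RecurrentInputTypes {
    public static void main(String[] args) {
        long size = 5, tsLength = 7;
        InputType ncw = InputType.recurrent(size);                             // defaults to NCW
        InputType nwc = InputType.recurrent(size, RNNFormat.NWC);              // NWC, no fixed length
        InputType nwcLen = InputType.recurrent(size, tsLength, RNNFormat.NWC); // NWC, fixed length
        System.out.println(ncw + " | " + nwc + " | " + nwcLen);
    }
}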
/**
* Input type for convolutional (CNN) data, that is 4d with shape [miniBatchSize, channels, height, width].
* For CNN data that has been flattened, use {@link #convolutionalFlat(long, long, long)}
@ -216,14 +224,23 @@ public abstract class InputType implements Serializable {
public static class InputTypeRecurrent extends InputType {
private long size;
private long timeSeriesLength;
private RNNFormat format = RNNFormat.NCW;
public InputTypeRecurrent(long size) {
this(size, -1);
}
public InputTypeRecurrent(long size, long timeSeriesLength){
this(size, timeSeriesLength, RNNFormat.NCW);
}
public InputTypeRecurrent(@JsonProperty("size") long size, @JsonProperty("timeSeriesLength") long timeSeriesLength) {
public InputTypeRecurrent(long size, RNNFormat format){
this(size, -1, format);
}
public InputTypeRecurrent(@JsonProperty("size") long size,
@JsonProperty("timeSeriesLength") long timeSeriesLength,
@JsonProperty("format") RNNFormat format) {
this.size = size;
this.timeSeriesLength = timeSeriesLength;
this.format = format;
}
@Override
@ -234,9 +251,9 @@ public abstract class InputType implements Serializable {
@Override
public String toString() {
if (timeSeriesLength > 0) {
return "InputTypeRecurrent(" + size + ",timeSeriesLength=" + timeSeriesLength + ")";
return "InputTypeRecurrent(" + size + ",timeSeriesLength=" + timeSeriesLength + ",format=" + format + ")";
} else {
return "InputTypeRecurrent(" + size + ")";
return "InputTypeRecurrent(" + size + ",format=" + format + ")";
}
}
@ -251,8 +268,23 @@ public abstract class InputType implements Serializable {
@Override
public long[] getShape(boolean includeBatchDim) {
if(includeBatchDim) return new long[]{-1, size, timeSeriesLength};
else return new long[]{size, timeSeriesLength};
if (includeBatchDim){
if (format == RNNFormat.NCW){
return new long[]{-1, size, timeSeriesLength};
}
else{
return new long[]{-1, timeSeriesLength, size};
}
}
else{
if (format == RNNFormat.NCW){
return new long[]{size, timeSeriesLength};
}
else{
return new long[]{timeSeriesLength, size};
}
}
}
}
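A worked example of the branch logic above (a sketch; the expected shapes follow directly from the format):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;

public class RecurrentShapes {
    public static void main(String[] args) {
        InputType.InputTypeRecurrent r = new InputType.InputTypeRecurrent(5, 7, RNNFormat.NWC);
        System.out.println(java.util.Arrays.toString(r.getShape(true)));  // [-1, 7, 5]
        System.out.println(java.util.Arrays.toString(r.getShape(false))); // [7, 5]
    }
}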


@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.weights.IWeightInit;
@ -35,10 +36,12 @@ import java.util.List;
public abstract class BaseRecurrentLayer extends FeedForwardLayer {
protected IWeightInit weightInitFnRecurrent;
protected RNNFormat rnnDataFormat = RNNFormat.NCW;
protected BaseRecurrentLayer(Builder builder) {
super(builder);
this.weightInitFnRecurrent = builder.weightInitFnRecurrent;
this.rnnDataFormat = builder.rnnDataFormat;
}
@Override
@ -51,7 +54,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
return InputType.recurrent(nOut, itr.getTimeSeriesLength());
return InputType.recurrent(nOut, itr.getTimeSeriesLength(), itr.getFormat());
}
@Override
@ -64,12 +67,13 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.nIn = r.getSize();
this.rnnDataFormat = r.getFormat();
}
}
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat,getLayerName());
}
@NoArgsConstructor
@ -77,6 +81,12 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
@Setter
public static abstract class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {
/**
* Set the format of data expected by the RNN. NCW = [miniBatchSize, size, timeSeriesLength],
* NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
*/
protected RNNFormat rnnDataFormat = RNNFormat.NCW;
/**
* Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
@ -163,5 +173,10 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
this.setWeightInitFnRecurrent(new WeightInitDistribution(dist));
return (T) this;
}
public T dataFormat(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
return (T)this;
}
}
}
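A minimal configuration sketch using the new dataFormat(...) builder option together with the format-aware InputType, mirroring the parameterized tests earlier in this diff:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;

public class NwcRnnConfig {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                // NWC: this layer now expects input of shape [minibatch, timeSeriesLength, size]
                .layer(new SimpleRnn.Builder().nIn(5).nOut(6).dataFormat(RNNFormat.NWC).build())
                .setInputType(InputType.recurrent(5, RNNFormat.NWC))
                .build();
        System.out.println(conf.toJson());
    }
}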


@ -22,6 +22,7 @@ import lombok.NoArgsConstructor;
import lombok.ToString;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.Convolution1DUtils;
@ -114,7 +115,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ "\"): input is null");
}
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName());
}
public static class Builder extends ConvolutionLayer.BaseConvBuilder<Builder> {


@ -87,7 +87,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
return null;
case RNN:
//RNN -> FF
return new RnnToFeedForwardPreProcessor();
return new RnnToFeedForwardPreProcessor(((InputType.InputTypeRecurrent)inputType).getFormat());
case CNN:
//CNN -> FF
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;


@ -22,6 +22,7 @@ import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
@ -528,7 +529,7 @@ public class InputTypeUtil {
}
}
public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, String layerName) {
public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, RNNFormat rnnDataFormat, String layerName) {
if (inputType == null) {
throw new IllegalStateException(
"Invalid input for RNN layer (layer name = \"" + layerName + "\"): input type is null");
@ -539,14 +540,14 @@ public class InputTypeUtil {
case CNNFlat:
//FF -> RNN or CNNFlat -> RNN
//In either case, input data format is a row vector per example
return new FeedForwardToRnnPreProcessor();
return new FeedForwardToRnnPreProcessor(rnnDataFormat);
case RNN:
//RNN -> RNN: No preprocessor necessary
return null;
case CNN:
//CNN -> RNN
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return new CnnToRnnPreProcessor(c.getHeight(), c.getWidth(), c.getChannels());
return new CnnToRnnPreProcessor(c.getHeight(), c.getWidth(), c.getChannels(), rnnDataFormat);
default:
throw new RuntimeException("Unknown input type: " + inputType);
}


@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
@ -86,7 +87,7 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName());
}
@Override


@ -20,6 +20,7 @@ import lombok.*;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
@ -136,7 +137,7 @@ public class LocallyConnected1D extends SameDiffLayer {
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
}
@Override


@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
@ -92,7 +93,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
}
@Override


@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
@ -53,12 +54,13 @@ import java.util.Map;
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class RnnLossLayer extends FeedForwardLayer {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
protected ILossFunction lossFn;
private RnnLossLayer(Builder builder) {
super(builder);
this.setLossFn(builder.lossFn);
this.rnnDataFormat = builder.rnnDataFormat;
}
@Override
@ -91,7 +93,7 @@ public class RnnLossLayer extends FeedForwardLayer {
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
}
@Override
@ -111,8 +113,9 @@ public class RnnLossLayer extends FeedForwardLayer {
public static class Builder extends BaseOutputLayer.Builder<Builder> {
public Builder() {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public Builder() {
}
/**
@ -153,6 +156,14 @@ public class RnnLossLayer extends FeedForwardLayer {
"This layer has no parameters, thus nIn will always equal nOut.");
}
/**
* @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength],
* NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
*/
public Builder dataFormat(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
return this;
}
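A usage sketch for the builder option above, matching the chain used in TestRnnLayers earlier in this diff; with NWC, activations and labels are both [miniBatchSize, timeSeriesLength, size]:

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class RnnLossNwc {
    public static void main(String[] args) {
        RnnLossLayer loss = new RnnLossLayer.Builder()
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.MCXENT)
                .dataFormat(RNNFormat.NWC)   // labels must then be [mb, time, size]
                .build();
        System.out.println(loss);
    }
}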
@Override
@SuppressWarnings("unchecked")
public RnnLossLayer build() {


@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
@ -51,9 +52,11 @@ import java.util.Map;
@EqualsAndHashCode(callSuper = true)
public class RnnOutputLayer extends BaseOutputLayer {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
private RnnOutputLayer(Builder builder) {
super(builder);
initializeConstraints(builder);
this.rnnDataFormat = builder.rnnDataFormat;
}
@Override
@ -85,7 +88,7 @@ public class RnnOutputLayer extends BaseOutputLayer {
}
InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
return InputType.recurrent(nOut, itr.getTimeSeriesLength());
return InputType.recurrent(nOut, itr.getTimeSeriesLength(), itr.getFormat());
}
@Override
@ -97,18 +100,20 @@ public class RnnOutputLayer extends BaseOutputLayer {
if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.rnnDataFormat = r.getFormat();
this.nIn = r.getSize();
}
}
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat, getLayerName());
}
public static class Builder extends BaseOutputLayer.Builder<Builder> {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public Builder() {
//Set default activation function to softmax (to match default loss function MCXENT)
this.setActivationFn(new ActivationSoftmax());
@ -137,5 +142,14 @@ public class RnnOutputLayer extends BaseOutputLayer {
public RnnOutputLayer build() {
return new RnnOutputLayer(this);
}
/**
* @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength],
* NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
*/
public Builder dataFormat(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
return this;
}
}
}


@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
@ -75,7 +76,7 @@ public class SelfAttentionLayer extends SameDiffLayer {
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName());
}
@Override


@ -22,6 +22,7 @@ import lombok.NoArgsConstructor;
import lombok.ToString;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.Convolution1DUtils;
@ -105,7 +106,7 @@ public class Subsampling1DLayer extends SubsamplingLayer {
+ "\"): input is null");
}
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
}
@Override


@ -20,6 +20,7 @@ import lombok.*;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
@ -104,7 +105,7 @@ public class ZeroPadding1DLayer extends NoParamLayer {
+ "\"): input is null");
}
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
}
@Override


@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
@ -30,6 +31,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.params.BidirectionalParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
@ -124,6 +126,10 @@ public class Bidirectional extends Layer {
}
}
public RNNFormat getRNNDataFormat(){
return TimeSeriesUtils.getFormatFromRnnLayer(fwd);
}
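Usage sketch: the wrapper derives its format from the wrapped forward layer (assuming TimeSeriesUtils.getFormatFromRnnLayer reads the sub-layer's configured rnnDataFormat):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;

public class BidirectionalFormat {
    public static void main(String[] args) {
        Bidirectional bi = new Bidirectional(new LSTM.Builder()
                .nIn(5).nOut(5).dataFormat(RNNFormat.NWC).build());
        System.out.println(bi.getRNNDataFormat()); // expected: NWC
    }
}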
@Override
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
@ -170,7 +176,7 @@ public class Bidirectional extends Layer {
} else {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) outOrig;
if (mode == Mode.CONCAT) {
return InputType.recurrent(2 * r.getSize());
return InputType.recurrent(2 * r.getSize(), getRNNDataFormat());
} else {
return r;
}


@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
import lombok.NonNull;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
@ -29,17 +30,19 @@ import java.util.Collection;
@EqualsAndHashCode(callSuper = true)
public class TimeDistributed extends BaseWrapperLayer {
private final int timeAxis;
private RNNFormat rnnDataFormat = RNNFormat.NCW;
/**
* @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayer
* @param rnnDataFormat Format of the RNN activations: NCW = [minibatch, size, sequenceLength], NWC = [minibatch, sequenceLength, size]
*/
public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("timeAxis") int timeAxis) {
public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) {
super(underlying);
this.timeAxis = timeAxis;
this.rnnDataFormat = rnnDataFormat;
}
public TimeDistributed(Layer underlying){
super(underlying);
}
@Override
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
@ -47,7 +50,7 @@ public class TimeDistributed extends BaseWrapperLayer {
NeuralNetConfiguration conf2 = conf.clone();
conf2.setLayer(((TimeDistributed) conf2.getLayer()).getUnderlying());
return new TimeDistributedLayer(underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView,
initializeParams, networkDataType), timeAxis);
initializeParams, networkDataType), rnnDataFormat);
}
@Override
@ -59,7 +62,7 @@ public class TimeDistributed extends BaseWrapperLayer {
InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
InputType ff = InputType.feedForward(rnn.getSize());
InputType ffOut = underlying.getOutputType(layerIndex, ff);
return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength());
return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength(), rnnDataFormat);
}
@Override
@ -70,6 +73,7 @@ public class TimeDistributed extends BaseWrapperLayer {
InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
InputType ff = InputType.feedForward(rnn.getSize());
this.rnnDataFormat = rnn.getFormat();
underlying.setNIn(ff, override);
}
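Both constructors of the reworked wrapper, as a sketch; the one-arg form leaves the format to be picked up from the network's InputType via setNIn(...) above:

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;

public class TimeDistributedUsage {
    public static void main(String[] args) {
        DenseLayer dense = new DenseLayer.Builder().nIn(3).nOut(3).build();
        // Explicit format (replaces the removed int timeAxis argument):
        TimeDistributed explicit = new TimeDistributed(dense, RNNFormat.NWC);
        // Format inferred later from the input type:
        TimeDistributed inferred = new TimeDistributed(dense);
        System.out.println(explicit + " / " + inferred);
    }
}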


@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.preprocessor;
import lombok.*;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;
@ -38,7 +39,7 @@ import java.util.Arrays;
* Functionally equivalent to combining CnnToFeedForwardPreProcessor + FeedForwardToRnnPreProcessor<br>
* Specifically, this does two things:<br>
* (a) Reshape 4d activations out of CNN layer, with shape [timeSeriesLength*miniBatchSize, numChannels, inputHeight, inputWidth])
* into 3d (time series) activations (with shape [numExamples, inputHeight*inputWidth*numChannels, timeSeriesLength])
* into 3d (time series) activations (with shape [miniBatchSize, inputHeight*inputWidth*numChannels, timeSeriesLength])
* for use in RNN layers<br>
* (b) Reshapes 3d epsilons (weights.*deltas) out of RNN layer (with shape
* [miniBatchSize,inputHeight*inputWidth*numChannels,timeSeriesLength]) into 4d epsilons with shape
@ -52,6 +53,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {
private long inputHeight;
private long inputWidth;
private long numChannels;
private RNNFormat rnnDataFormat = RNNFormat.NCW;
@Getter(AccessLevel.NONE)
@Setter(AccessLevel.NONE)
@ -59,11 +61,20 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {
@JsonCreator
public CnnToRnnPreProcessor(@JsonProperty("inputHeight") long inputHeight,
@JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) {
@JsonProperty("inputWidth") long inputWidth,
@JsonProperty("numChannels") long numChannels,
@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = numChannels;
this.product = inputHeight * inputWidth * numChannels;
this.rnnDataFormat = rnnDataFormat;
}
public CnnToRnnPreProcessor(long inputHeight,
long inputWidth,
long numChannels){
this(inputHeight, inputWidth, numChannels, RNNFormat.NCW);
}
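A shape walk-through of preProcess under each format (a sketch with illustrative sizes; LayerWorkspaceMgr.noWorkspaces() is the usual no-workspace test helper):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CnnToRnnFormats {
    public static void main(String[] args) {
        // CNN activations: [mb*T, c, h, w] = [6, 2, 4, 4], with mb = 3 and T = 2
        INDArray cnnAct = Nd4j.rand(DataType.FLOAT, 6, 2, 4, 4);
        LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces();
        INDArray ncw = new CnnToRnnPreProcessor(4, 4, 2).preProcess(cnnAct, 3, mgr);
        INDArray nwc = new CnnToRnnPreProcessor(4, 4, 2, RNNFormat.NWC).preProcess(cnnAct, 3, mgr);
        System.out.println(java.util.Arrays.toString(ncw.shape())); // [3, 32, 2]
        System.out.println(java.util.Arrays.toString(nwc.shape())); // [3, 2, 32] (no trailing permute)
    }
}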
@Override
@ -90,14 +101,19 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {
//Second: reshape 2d to 3d, as per FeedForwardToRnnPreProcessor
INDArray reshaped = workspaceMgr.dup(ArrayType.ACTIVATIONS, twod, 'f');
reshaped = reshaped.reshape('f', miniBatchSize, shape[0] / miniBatchSize, product);
if (rnnDataFormat == RNNFormat.NCW) {
return reshaped.permute(0, 2, 1);
}
return reshaped;
}
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
if (output.ordering() == 'c' || !Shape.hasDefaultStridesForShape(output))
output = output.dup('f');
if (rnnDataFormat == RNNFormat.NWC){
output = output.permute(0, 2, 1);
}
val shape = output.shape();
INDArray output2d;
if (shape[0] == 1) {
@ -122,7 +138,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {
@Override
public CnnToRnnPreProcessor clone() {
return new CnnToRnnPreProcessor(inputHeight, inputWidth, numChannels);
return new CnnToRnnPreProcessor(inputHeight, inputWidth, numChannels, rnnDataFormat);
}
@Override
@ -133,7 +149,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
val outSize = c.getChannels() * c.getHeight() * c.getWidth();
return InputType.recurrent(outSize);
return InputType.recurrent(outSize, rnnDataFormat);
}
@Override


@ -21,6 +21,7 @@ import lombok.NoArgsConstructor;
import lombok.val;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -28,7 +29,7 @@ import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Arrays;
/**
@ -48,7 +49,11 @@ import java.util.Arrays;
@Data
@NoArgsConstructor
public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
//Need to reshape FF activations (2d) activations to 3d (for input into RNN layer)
@ -60,7 +65,10 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
val shape = input.shape();
INDArray reshaped = input.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]);
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reshaped.permute(0, 2, 1));
if (rnnDataFormat == RNNFormat.NCW){
reshaped = reshaped.permute(0, 2, 1);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reshaped);
}
@Override
@ -71,6 +79,9 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
"Invalid input: expect NDArray with rank 3 (i.e., epsilons from RNN layer)");
if (output.ordering() != 'f' || !Shape.hasDefaultStridesForShape(output))
output = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, output, 'f');
if (rnnDataFormat == RNNFormat.NWC){
output = output.permute(0, 2, 1);
}
val shape = output.shape();
INDArray ret;
@ -87,12 +98,7 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
@Override
public FeedForwardToRnnPreProcessor clone() {
try {
FeedForwardToRnnPreProcessor clone = (FeedForwardToRnnPreProcessor) super.clone();
return clone;
} catch (CloneNotSupportedException e) {
throw new RuntimeException(e);
}
return new FeedForwardToRnnPreProcessor(rnnDataFormat);
}
@Override
@ -104,10 +110,10 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
if (inputType.getType() == InputType.Type.FF) {
InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward) inputType;
return InputType.recurrent(ff.getSize());
return InputType.recurrent(ff.getSize(), rnnDataFormat);
} else {
InputType.InputTypeConvolutionalFlat cf = (InputType.InputTypeConvolutionalFlat) inputType;
return InputType.recurrent(cf.getFlattenedSize());
return InputType.recurrent(cf.getFlattenedSize(), rnnDataFormat);
}
}
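The same idea for the FF-to-RNN case, as a sketch: the 2d-to-3d reshape is identical for both formats, and only NCW gets the trailing permute.

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class FfToRnnFormats {
    public static void main(String[] args) {
        // FF activations: [mb*T, size] = [12, 5], with mb = 3 and T = 4
        INDArray ffAct = Nd4j.rand(DataType.FLOAT, 12, 5);
        LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces();
        INDArray ncw = new FeedForwardToRnnPreProcessor(RNNFormat.NCW).preProcess(ffAct, 3, mgr);
        INDArray nwc = new FeedForwardToRnnPreProcessor(RNNFormat.NWC).preProcess(ffAct, 3, mgr);
        System.out.println(java.util.Arrays.toString(ncw.shape())); // [3, 5, 4]
        System.out.println(java.util.Arrays.toString(nwc.shape())); // [3, 4, 5]
    }
}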


@ -19,8 +19,10 @@ package org.deeplearning4j.nn.conf.preprocessor;
import lombok.*;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.primitives.Pair;
@ -52,19 +54,27 @@ public class RnnToCnnPreProcessor implements InputPreProcessor {
private int inputHeight;
private int inputWidth;
private int numChannels;
private RNNFormat rnnDataFormat = RNNFormat.NCW;
@Getter(AccessLevel.NONE)
@Setter(AccessLevel.NONE)
private int product;
public RnnToCnnPreProcessor(@JsonProperty("inputHeight") int inputHeight,
@JsonProperty("inputWidth") int inputWidth, @JsonProperty("numChannels") int numChannels) {
@JsonProperty("inputWidth") int inputWidth,
@JsonProperty("numChannels") int numChannels,
@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = numChannels;
this.product = inputHeight * inputWidth * numChannels;
this.rnnDataFormat = rnnDataFormat;
}
public RnnToCnnPreProcessor(int inputHeight,
int inputWidth,
int numChannels){
this(inputHeight, inputWidth, numChannels, RNNFormat.NCW);
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
@ -72,6 +82,9 @@ public class RnnToCnnPreProcessor implements InputPreProcessor {
input = input.dup('f');
//Input: 3d activations (RNN)
//Output: 4d activations (CNN)
if (rnnDataFormat == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
val shape = input.shape();
INDArray in2d;
if (shape[0] == 1) {
@ -98,14 +111,17 @@ public class RnnToCnnPreProcessor implements InputPreProcessor {
val shape = output.shape();
//First: reshape 4d to 2d
INDArray twod = output.reshape('c', output.size(0), ArrayUtil.prod(output.shape()) / output.size(0));
//Second: reshape 2d to 4d
//Second: reshape 2d to 3d
INDArray reshaped = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, twod, 'f').reshape('f', miniBatchSize, shape[0] / miniBatchSize, product);
return reshaped.permute(0, 2, 1);
if (rnnDataFormat == RNNFormat.NCW) {
reshaped = reshaped.permute(0, 2, 1);
}
return reshaped;
}
@Override
public RnnToCnnPreProcessor clone() {
return new RnnToCnnPreProcessor(inputHeight, inputWidth, numChannels);
return new RnnToCnnPreProcessor(inputHeight, inputWidth, numChannels, rnnDataFormat);
}
@Override


@ -16,11 +16,14 @@
package org.deeplearning4j.nn.conf.preprocessor;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -28,6 +31,7 @@ import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Arrays;
@ -47,8 +51,14 @@ import java.util.Arrays;
*/
@Data
@Slf4j
@NoArgsConstructor
public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
//Need to reshape RNN activations (3d) activations to 2d (for input into feed forward layer)
@ -59,10 +69,13 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
if (input.ordering() != 'f' || !Shape.hasDefaultStridesForShape(input))
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'f');
if (rnnDataFormat == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
val shape = input.shape();
INDArray ret;
if (shape[0] == 1) {
ret = input.tensorAlongDimension(0, 1, 2).permutei(1, 0); //Edge case: miniBatchSize==1
ret = input.tensorAlongDimension(0, 1, 2).permute(1, 0); //Edge case: miniBatchSize==1
} else if (shape[2] == 1) {
ret = input.tensorAlongDimension(0, 1, 0); //Edge case: timeSeriesLength=1
} else {
@ -85,17 +98,15 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
val shape = output.shape();
INDArray reshaped = output.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]);
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, reshaped.permute(0, 2, 1));
if (rnnDataFormat == RNNFormat.NCW){
reshaped = reshaped.permute(0, 2, 1);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, reshaped);
}
@Override
public RnnToFeedForwardPreProcessor clone() {
try {
RnnToFeedForwardPreProcessor clone = (RnnToFeedForwardPreProcessor) super.clone();
return clone;
} catch (CloneNotSupportedException e) {
throw new RuntimeException(e);
}
return new RnnToFeedForwardPreProcessor(rnnDataFormat);
}
@Override


@ -18,7 +18,10 @@ package org.deeplearning4j.nn.layers.recurrent;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -26,7 +29,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public abstract class BaseRecurrentLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseLayer>
public abstract class BaseRecurrentLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer>
extends BaseLayer<LayerConfT> implements RecurrentLayer {
/**
@ -85,4 +88,19 @@ public abstract class BaseRecurrentLayer<LayerConfT extends org.deeplearning4j.n
tBpttStateMap.putAll(state);
}
public RNNFormat getDataFormat(){
return layerConf().getRnnDataFormat();
}
protected INDArray permuteIfNWC(INDArray arr){
if (arr == null){
return null;
}
if (getDataFormat() == RNNFormat.NWC){
return arr.permute(0, 2, 1);
}
return arr;
}
}
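The helper above is just a conditional permute; a standalone sketch of its effect on shapes (with the format decision passed in instead of read from the layer config):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class PermuteIfNwcSketch {
    static INDArray permuteIfNWC(INDArray arr, boolean nwc) {
        if (arr == null) return null;
        return nwc ? arr.permute(0, 2, 1) : arr;   // permute returns a view, no copy
    }

    public static void main(String[] args) {
        INDArray nwcIn = Nd4j.rand(DataType.FLOAT, 3, 7, 5);        // [mb, time, size]
        System.out.println(java.util.Arrays.toString(
                permuteIfNWC(nwcIn, true).shape()));                // [3, 5, 7] = [mb, size, time]
    }
}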


@ -25,6 +25,7 @@ import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
@ -78,6 +79,9 @@ public class BidirectionalLayer implements RecurrentLayer {
this.paramsView = paramsView;
}
private RNNFormat getRNNDataFormat(){
return layerConf.getRNNDataFormat();
}
@Override
public INDArray rnnTimeStep(INDArray input, LayerWorkspaceMgr workspaceMgr) {
throw new UnsupportedOperationException("Cannot RnnTimeStep bidirectional layers");
@ -140,7 +144,10 @@ public class BidirectionalLayer implements RecurrentLayer {
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
INDArray eFwd;
INDArray eBwd;
boolean permute = getRNNDataFormat() == RNNFormat.NWC && epsilon.rank() == 3;
if (permute){
epsilon = epsilon.permute(0, 2, 1);
}
val n = epsilon.size(1)/2;
switch (layerConf.getMode()){
case ADD:
@ -165,6 +172,10 @@ public class BidirectionalLayer implements RecurrentLayer {
eBwd = TimeSeriesUtils.reverseTimeSeries(eBwd, workspaceMgr, ArrayType.BP_WORKING_MEM);
if (permute){
eFwd = eFwd.permute(0, 2, 1);
eBwd = eBwd.permute(0, 2, 1);
}
Pair<Gradient,INDArray> g1 = fwd.backpropGradient(eFwd, workspaceMgr);
Pair<Gradient,INDArray> g2 = bwd.backpropGradient(eBwd, workspaceMgr);
@ -176,7 +187,9 @@ public class BidirectionalLayer implements RecurrentLayer {
g.gradientForVariable().put(BidirectionalParamInitializer.BACKWARD_PREFIX + e.getKey(), e.getValue());
}
INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2.getRight(), workspaceMgr, ArrayType.BP_WORKING_MEM);
INDArray g2Right = permute ? g2.getRight().permute(0, 2, 1): g2.getRight();
INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2Right, workspaceMgr, ArrayType.BP_WORKING_MEM);
g2Reversed = permute? g2Reversed.permute(0, 2, 1): g2Reversed;
INDArray epsOut = g1.getRight().addi(g2Reversed);
return new Pair<>(g, epsOut);
@ -186,25 +199,38 @@ public class BidirectionalLayer implements RecurrentLayer {
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
INDArray out1 = fwd.activate(training, workspaceMgr);
INDArray out2 = bwd.activate(training, workspaceMgr);
boolean permute = getRNNDataFormat() == RNNFormat.NWC && out1.rank() == 3;
if(permute){
out1 = out1.permute(0, 2, 1);
out2 = out2.permute(0, 2, 1);
}
//Reverse the output time series. Note: when using LastTimeStepLayer, output can be rank 2
out2 = out2.rank() == 2 ? out2 : TimeSeriesUtils.reverseTimeSeries(out2, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray ret;
switch (layerConf.getMode()){
case ADD:
return out1.addi(out2);
ret = out1.addi(out2);
break;
case MUL:
//TODO may be more efficient ways than this...
this.outFwd = out1.detach();
this.outBwd = out2.detach();
return workspaceMgr.dup(ArrayType.ACTIVATIONS, out1).muli(out2);
ret = workspaceMgr.dup(ArrayType.ACTIVATIONS, out1).muli(out2);
break;
case AVERAGE:
return out1.addi(out2).muli(0.5);
ret = out1.addi(out2).muli(0.5);
break;
case CONCAT:
INDArray ret = Nd4j.concat(1, out1, out2);
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
ret = Nd4j.concat(1, out1, out2);
ret = workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
break;
default:
throw new RuntimeException("Unknown mode: " + layerConf.getMode());
}
if (permute){
ret = ret.permute(0, 2, 1);
}
return ret;
}
@Override
@ -465,7 +491,9 @@ public class BidirectionalLayer implements RecurrentLayer {
public void setInput(INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) {
this.input = input;
fwd.setInput(input, layerWorkspaceMgr);
if (getRNNDataFormat() == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
INDArray reversed;
if(!input.isAttached()){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
@ -478,6 +506,9 @@ public class BidirectionalLayer implements RecurrentLayer {
reversed = TimeSeriesUtils.reverseTimeSeries(input);
}
}
if (getRNNDataFormat() == RNNFormat.NWC){
reversed = reversed.permute(0, 2, 1);
}
bwd.setInput(reversed, layerWorkspaceMgr);
}


@ -88,12 +88,12 @@ public class GravesBidirectionalLSTM
}
final FwdPassReturn fwdPass = activateHelperDirectional(true, null, null, true, true, workspaceMgr);
fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput);
final Pair<Gradient, INDArray> forwardsGradient = LSTMHelpers.backpropGradientHelper(this,
this.conf,
this.layerConf().getGateActivationFn(), this.input,
this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), epsilon,
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), permuteIfNWC(epsilon),
truncatedBPTT, tbpttBackwardLength, fwdPass, true,
GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS,
GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS,
@ -106,16 +106,17 @@ public class GravesBidirectionalLSTM
final Pair<Gradient, INDArray> backwardsGradient = LSTMHelpers.backpropGradientHelper(this,
this.conf,
this.layerConf().getGateActivationFn(), this.input,
this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), epsilon,
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), permuteIfNWC(epsilon),
truncatedBPTT, tbpttBackwardLength, backPass, false,
GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS,
GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS,
GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, gradientViews, maskArray, true,
null, workspaceMgr, layerConf().isHelperAllowFallback());
forwardsGradient.setSecond(permuteIfNWC(forwardsGradient.getSecond()));
backwardsGradient.setSecond(permuteIfNWC(backwardsGradient.getSecond()));
//merge the gradient, which is key value pair of String,INDArray
//the keys for forwards and backwards should be different
@ -171,7 +172,7 @@ public class GravesBidirectionalLSTM
} else {
forwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
this.input, getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS),
permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), training, null, null,
forBackprop || (cacheMode != CacheMode.NONE && training), true,
@ -179,7 +180,7 @@ public class GravesBidirectionalLSTM
forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback());
backwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
this.input,
permuteIfNWC(this.input),
getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS), training, null, null,
@ -187,6 +188,8 @@ public class GravesBidirectionalLSTM
GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, maskArray, true, null,
forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback());
forwardsEval.fwdPassOutput = permuteIfNWC(forwardsEval.fwdPassOutput);
backwardsEval.fwdPassOutput = permuteIfNWC(backwardsEval.fwdPassOutput);
cachedPassForward = forwardsEval;
cachedPassBackward = backwardsEval;
}
@ -228,10 +231,12 @@ public class GravesBidirectionalLSTM
biasKey = GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS;
}
return LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), this.input,
FwdPassReturn ret = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
getParam(recurrentKey), getParam(inputKey), getParam(biasKey), training,
prevOutputActivations, prevMemCellState, forBackprop, forwards, inputKey, maskArray, true,
null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback());
ret.fwdPassOutput = permuteIfNWC(ret.fwdPassOutput);
return ret;
}
}


@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.nd4j.base.Preconditions;
@ -89,17 +90,17 @@ public class GravesLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.la
} else {
fwdPass = activateHelper(true, null, null, true, workspaceMgr);
}
fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput);
Pair<Gradient, INDArray> p = LSTMHelpers.backpropGradientHelper(this,
this.conf, this.layerConf().getGateActivationFn(), this.input,
recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true,
this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true,
GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY,
GravesLSTMParamInitializer.BIAS_KEY, gradientViews, maskArray, true, null,
workspaceMgr, layerConf().isHelperAllowFallback());
weightNoiseParams.clear();
- p.setSecond(backpropDropOutIfPresent(p.getSecond()));
+ p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond())));
return p;
}
@@ -117,8 +118,8 @@ public class GravesLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.la
private FwdPassReturn activateHelper(final boolean training, final INDArray prevOutputActivations,
final INDArray prevMemCellState, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false);
- Preconditions.checkState(input.rank() == 3,
- "3D input expected to RNN layer expected, got " + input.rank());
+ Preconditions.checkState(this.input.rank() == 3,
+ "3D input expected to RNN layer, got " + this.input.rank());
applyDropOutIfNecessary(training, workspaceMgr);
// if (cacheMode == null)
@@ -136,18 +137,17 @@ public class GravesLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.la
final INDArray recurrentWeights = getParamWithNoise(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights = getParamWithNoise(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, training, workspaceMgr); //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray biases = getParamWithNoise(GravesLSTMParamInitializer.BIAS_KEY, training, workspaceMgr); //by row: IFOG //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T
+ INDArray input = permuteIfNWC(this.input);
FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
- this.input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
+ input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
prevMemCellState, forBackprop || (cacheMode != CacheMode.NONE && training), true,
GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, maskArray, true, null,
cacheMode, workspaceMgr, layerConf().isHelperAllowFallback());
+ fwd.fwdPassOutput = permuteIfNWC(fwd.fwdPassOutput);
if (training && cacheMode != CacheMode.NONE) {
cachedFwdPass = fwd;
}
return fwd;
}
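The backprop hunks above all follow the same convention: permute the stored input and the incoming epsilon to NCW before calling the NCW-only LSTM helpers, then permute the returned epsilon back for the layer below. A self-contained shape round-trip, with arbitrary sizes, showing why no separate inverse is needed:

    import java.util.Arrays;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class NwcNcwRoundTrip {
        public static void main(String[] args) {
            INDArray nwc = Nd4j.rand(DataType.FLOAT, 2, 5, 3);  // NWC: [mb, seqLen, size]
            INDArray ncw = nwc.permute(0, 2, 1);                // NCW view: [2, 3, 5]
            System.out.println(Arrays.toString(ncw.shape()));                  // [2, 3, 5]
            System.out.println(Arrays.toString(ncw.permute(0, 2, 1).shape())); // [2, 5, 3]
        }
    }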


@@ -123,17 +123,16 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
} else {
fwdPass = activateHelper(true, null, null, true, workspaceMgr);
}
+ fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput);
Pair<Gradient,INDArray> p = LSTMHelpers.backpropGradientHelper(this,
- this.conf, this.layerConf().getGateActivationFn(), this.input,
- recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true,
+ this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
+ recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true,
LSTMParamInitializer.INPUT_WEIGHT_KEY, LSTMParamInitializer.RECURRENT_WEIGHT_KEY,
LSTMParamInitializer.BIAS_KEY, gradientViews, null, false, helper, workspaceMgr,
layerConf().isHelperAllowFallback());
weightNoiseParams.clear();
- p.setSecond(backpropDropOutIfPresent(p.getSecond()));
+ p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond())));
return p;
}
@@ -167,17 +166,18 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
final INDArray recurrentWeights = getParamWithNoise(LSTMParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights = getParamWithNoise(LSTMParamInitializer.INPUT_WEIGHT_KEY, training, workspaceMgr); //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray biases = getParamWithNoise(LSTMParamInitializer.BIAS_KEY, training, workspaceMgr); //by row: IFOG //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T
+ INDArray input = permuteIfNWC(this.input);
FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
- this.input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
+ input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
prevMemCellState, (training && cacheMode != CacheMode.NONE) || forBackprop, true,
LSTMParamInitializer.INPUT_WEIGHT_KEY, maskArray, false, helper,
forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback());
+ fwd.fwdPassOutput = permuteIfNWC(fwd.fwdPassOutput);
if (training && cacheMode != CacheMode.NONE) {
cachedFwdPass = fwd;
}
return fwd;
}
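From the configuration side, the format is chosen per layer. A sketch of the intended usage, assuming the builder setter that accompanies these changes is named dataFormat(RNNFormat) on both the recurrent and RNN output builders (verify the name against the merged javadoc before relying on it):

    // Hypothetical config for NWC ([minibatch, seqLength, size]) inputs end to end:
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new LSTM.Builder().nIn(3).nOut(8)
                    .dataFormat(RNNFormat.NWC).build())
            .layer(new RnnOutputLayer.Builder().nIn(8).nOut(2)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .dataFormat(RNNFormat.NWC).build())
            .build();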


@@ -465,7 +465,6 @@ public class LSTMHelpers {
val miniBatchSize = epsilon.size(0);
boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
val timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
INDArray wFFTranspose = null;
INDArray wOOTranspose = null;
INDArray wGGTranspose = null;
@@ -573,14 +572,14 @@ public class LSTMHelpers {
nablaCellState = Nd4j.create(inputWeights.dataType(), new long[]{miniBatchSize, hiddenLayerSize}, 'f');
}
- INDArray prevMemCellState = (iTimeIndex == 0 ? fwdPass.prevMemCell : fwdPass.memCellState[(int) (time - inext)]);
+ INDArray prevMemCellState = (iTimeIndex == 0 ? fwdPass.prevMemCell : fwdPass.memCellState[(time - inext)]);
INDArray prevHiddenUnitActivation =
- (iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[(int) (time - inext)]);
- INDArray currMemCellState = fwdPass.memCellState[(int) time];
+ (iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[(time - inext)]);
+ INDArray currMemCellState = fwdPass.memCellState[time];
//LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)
- INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension((int) time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
+ INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f'); //Shape: [m,n^L]
if (iTimeIndex != timeSeriesLength - 1) {
//if t == timeSeriesLength-1 then deltaiNext etc are zeros
@@ -666,7 +665,7 @@ public class LSTMHelpers {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid using errors from a masked time step
// to calculate the parameter gradients. Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step)
timeStepMaskColumn = maskArray.getColumn(time, true);
- deltaifogNext.muliColumnVector(timeStepMaskColumn);
+ deltaifogNext.muli(timeStepMaskColumn);
//Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients
}
@@ -737,7 +736,7 @@ public class LSTMHelpers {
if (maskArray != null) {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything
// but 0s to the layer below at this time step (for the given example)
- epsilonNextSlice.muliColumnVector(timeStepMaskColumn);
+ epsilonNextSlice.muli(timeStepMaskColumn);
}
}
}
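The muliColumnVector to muli change relies on getColumn(time, true) keeping the column dimension: the per-example mask arrives as a rank-2 [miniBatch, 1] array, and the in-place multiply broadcasts it across all gate columns. A small standalone check of that behaviour (shapes chosen for illustration):

    INDArray maskArray = Nd4j.createFromArray(new float[][]{{1, 1, 0}, {1, 0, 0}}); // [mb=2, ts=3]
    INDArray timeStepMaskColumn = maskArray.getColumn(1, true);                     // [2, 1], dim kept
    INDArray deltaifogNext = Nd4j.ones(DataType.FLOAT, 2, 8);                       // [mb, 4*hiddenSize]
    deltaifogNext.muli(timeStepMaskColumn); // broadcasts: row 0 unchanged, row 1 zeroed in place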


@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.recurrent;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
+ import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.util.TimeSeriesUtils;
@@ -59,11 +60,32 @@ public class LastTimeStepLayer extends BaseWrapperLayer {
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
- INDArray newEps = Nd4j.create(epsilon.dataType(), origOutputShape, 'f');
+ long[] newEpsShape = origOutputShape;
+ boolean nwc = (underlying instanceof BaseRecurrentLayer &&
+ ((BaseRecurrentLayer) underlying).getDataFormat() == RNNFormat.NWC)||
+ (underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof
+ BaseRecurrentLayer && ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat()
+ == RNNFormat.NWC);
+ INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f');
if(lastTimeStepIdxs == null){
//no mask case
+ if (nwc){
+ newEps.put(new INDArrayIndex[]{all(), point(origOutputShape[1]-1), all()}, epsilon);
+ }
+ else{
newEps.put(new INDArrayIndex[]{all(), all(), point(origOutputShape[2]-1)}, epsilon);
+ }
} else {
+ if (nwc){
+ INDArrayIndex[] arr = new INDArrayIndex[]{null, null, all()};
+ //TODO probably possible to optimize this with reshape + scatter ops...
+ for( int i=0; i<lastTimeStepIdxs.length; i++ ){
+ arr[0] = point(i);
+ arr[1] = point(lastTimeStepIdxs[i]);
+ newEps.put(arr, epsilon.getRow(i));
+ }
+ }
+ else{
INDArrayIndex[] arr = new INDArrayIndex[]{null, all(), null};
//TODO probably possible to optimize this with reshape + scatter ops...
for( int i=0; i<lastTimeStepIdxs.length; i++ ){
@@ -72,6 +94,8 @@ public class LastTimeStepLayer extends BaseWrapperLayer {
newEps.put(arr, epsilon.getRow(i));
}
}
+ }
return underlying.backpropGradient(newEps, workspaceMgr);
}
@@ -103,10 +127,18 @@ public class LastTimeStepLayer extends BaseWrapperLayer {
"rank " + in.rank() + " with shape " + Arrays.toString(in.shape()));
}
origOutputShape = in.shape();
+ boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC;
+ // underlying instanceof BaseRecurrentLayer && ((BaseRecurrentLayer)underlying).getDataFormat() == RNNFormat.NWC)||
+ // underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof BaseRecurrentLayer &&
+ // ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat() == RNNFormat.NWC;
+ if (nwc){
+ in = in.permute(0, 2, 1);
+ }
INDArray mask = underlying.getMaskArray();
Pair<INDArray,int[]> p = TimeSeriesUtils.pullLastTimeSteps(in, mask, workspaceMgr, arrayType);
lastTimeStepIdxs = p.getSecond();
return p.getFirst();
}
}
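Whatever the format, the extracted step has shape [miniBatch, size]; only the axis that holds time moves. An illustration using the same indexing style as above (shapes are arbitrary; all() and point() are the usual NDArrayIndex static imports):

    INDArray ncw = Nd4j.rand(DataType.FLOAT, 2, 3, 5);                // [mb, size, ts]
    INDArray lastNcw = ncw.get(all(), all(), point(ncw.size(2) - 1)); // [2, 3]
    INDArray nwc = Nd4j.rand(DataType.FLOAT, 2, 5, 3);                // [mb, ts, size]
    INDArray lastNwc = nwc.get(all(), point(nwc.size(1) - 1), all()); // [2, 3]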


@@ -30,6 +30,9 @@ import org.nd4j.linalg.primitives.Pair;
import lombok.NonNull;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+ import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
+ import static org.deeplearning4j.nn.conf.RNNFormat.NWC;
/**
* Masks timesteps with activation equal to the specified masking value, defaulting to 0.0.
* Assumes that the input shape is [batch_size, input_size, timesteps].
@@ -76,7 +79,11 @@ public class MaskZeroLayer extends BaseWrapperLayer {
throw new IllegalArgumentException("Expected input of shape [batch_size, timestep_input_size, timestep], " +
"got shape "+Arrays.toString(input.shape()) + " instead");
}
- INDArray mask = input.eq(maskingValue).castTo(input.dataType()).sum(1).neq(input.shape()[1]);
+ if ((underlying instanceof BaseRecurrentLayer &&
+ ((BaseRecurrentLayer)underlying).getDataFormat() == NWC)){
+ input = input.permute(0, 2, 1);
+ }
+ INDArray mask = input.eq(maskingValue).castTo(input.dataType()).sum(1).neq(input.shape()[1]).castTo(input.dataType());
underlying.setMaskArray(mask.detach());
}
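The mask expression is dense, so spelled out: a time step is masked only when every feature at that step equals the masking value, and the NWC permute beforehand guarantees sum(1) always reduces over features. A worked example under that reading (values are illustrative only):

    double maskingValue = 0.0;
    INDArray input = Nd4j.createFromArray(new float[][][]{
            {{1, 0}, {2, 0}},    // example 0: both features are 0.0 at step 1
            {{0, 3}, {0, 4}}});  // example 1: both features are 0.0 at step 0
    INDArray eq = input.eq(maskingValue).castTo(input.dataType());         // 1 where feature == maskingValue
    INDArray mask = eq.sum(1).neq(input.size(1)).castTo(input.dataType()); // [mb, ts] = [[1, 0], [0, 1]]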


@@ -22,6 +22,7 @@ import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+ import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
@@ -60,6 +61,8 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(true);
+ INDArray input = this.input;
+ INDArray labels = this.labels;
if (input.rank() != 3)
throw new UnsupportedOperationException(
"Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " +
@@ -67,6 +70,10 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
if (labels == null)
throw new IllegalStateException("Labels are not set (null)");
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ input = input.permute(0, 2, 1);
+ labels = labels.permute(0, 2, 1);
+ }
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);
@@ -90,7 +97,9 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped);
INDArray delta3d = TimeSeriesUtils.reshape2dTo3d(delta2d, input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ delta3d = delta3d.permute(0, 2, 1);
+ }
// grab the empty gradient
Gradient gradient = new DefaultGradient();
return new Pair<>(gradient, delta3d);
@@ -159,13 +168,21 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false);
+ INDArray input = this.input;
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ input = input.permute(0, 2, 1);
+ }
if (input.rank() != 3)
throw new UnsupportedOperationException(
"Input must be rank 3. Got input with rank " + input.rank() + " " + layerId());
INDArray as2d = TimeSeriesUtils.reshape3dTo2d(input);
INDArray out2d = layerConf().getActivationFn().getActivation(workspaceMgr.dup(ArrayType.ACTIVATIONS, as2d, as2d.ordering()), training);
- return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, TimeSeriesUtils.reshape2dTo3d(out2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS));
+ INDArray ret = workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, TimeSeriesUtils.reshape2dTo3d(out2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS));
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ ret = ret.permute(0, 2, 1);
+ }
+ return ret;
}
@Override
@@ -196,6 +213,12 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
@Override
public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
INDArray input = this.input;
+ INDArray labels = this.labels;
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ input = input.permute(0, 2, 1);
+ labels = labels.permute(0, 2, 1);
+ }
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray maskReshaped;
@@ -228,10 +251,14 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
@Override
public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
//For RNN: need to sum up the score over each time step before returning.
+ INDArray input = this.input;
+ INDArray labels = this.labels;
if (input == null || labels == null)
throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ input = input.permute(0, 2, 1);
+ labels = labels.permute(0, 2, 1);
+ }
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.FF_WORKING_MEM);
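For context on the reshape3dTo2d / reshape2dTo3d calls that bracket the loss math: the NWC permute runs first precisely so both formats reach the reshape in NCW, where time is folded into the minibatch axis and the 2d loss scores every (example, timestep) row at once. A rough standalone illustration (noWorkspaces() is used only to keep the sketch self-contained):

    INDArray in3d = Nd4j.rand(DataType.FLOAT, 2, 5, 7);   // NCW: [mb, size, ts]
    INDArray in2d = TimeSeriesUtils.reshape3dTo2d(in3d);  // [mb*ts, size] = [14, 5]
    // ...compute per-row 2d loss / gradients here...
    INDArray back = TimeSeriesUtils.reshape2dTo3d(in2d, 2,
            LayerWorkspaceMgr.noWorkspaces(), ArrayType.ACTIVATION_GRAD); // back to [2, 5, 7]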


@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.recurrent;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+ import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
@@ -57,11 +58,15 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
}
+ int td = (layerConf().getRnnDataFormat()==RNNFormat.NCW)? 2: 1;
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
- Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
+ Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);
INDArray inputTemp = input;
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ this.input = input.permute(0, 2, 1);
+ }
this.input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
applyDropOutIfNecessary(true, workspaceMgr); //Edge case: we skip OutputLayer forward pass during training as this isn't required to calculate gradients
@@ -71,7 +76,9 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
INDArray epsilon2d = gradAndEpsilonNext.getSecond();
INDArray epsilon3d = TimeSeriesUtils.reshape2dTo3d(epsilon2d, input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ epsilon3d = epsilon3d.permute(0, 2, 1);
+ }
weightNoiseParams.clear();
//epsilon3d = backpropDropOutIfPresent(epsilon3d);
@@ -104,6 +111,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
if (input.rank() == 3) {
//Case when called from RnnOutputLayer
INDArray inputTemp = input;
+ input = (layerConf().getRnnDataFormat()==RNNFormat.NWC)? input.permute(0, 2, 1):input;
input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray out = super.preOutput(training, workspaceMgr);
this.input = inputTemp;
@@ -117,13 +125,17 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
@Override
protected INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
- if (labels.rank() == 3)
+ INDArray labels = this.labels;
+ if (labels.rank() == 3){
+ labels = (layerConf().getRnnDataFormat()==RNNFormat.NWC)?labels.permute(0, 2, 1):labels;
return TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, arrayType);
+ }
return labels;
}
@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
+ INDArray input = this.input;
if (input.rank() != 3)
throw new UnsupportedOperationException(
"Input must be rank 3. Got input with rank " + input.rank() + " " + layerId());
@@ -131,6 +143,9 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr);
applyDropOutIfNecessary(training, workspaceMgr);
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ input = input.permute(0, 2, 1);
+ }
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input.castTo(W.dataType()), workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray act2d = layerConf().getActivationFn().getActivation(input2d.mmul(W).addiRowVector(b), training);
@@ -144,7 +159,11 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
}
}
- return TimeSeriesUtils.reshape2dTo3d(act2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS);
+ INDArray ret = TimeSeriesUtils.reshape2dTo3d(act2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS);
+ if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
+ ret = ret.permute(0, 2, 1);
+ }
+ return ret;
}
@Override


@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.recurrent;
import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+ import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
@@ -50,6 +51,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn> {
public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";
public SimpleRnn(NeuralNetConfiguration conf, DataType dataType) {
super(conf, dataType);
}
@@ -92,6 +94,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
val nOut = layerConf().getNOut();
INDArray input = this.input.castTo(dataType); //No-op if correct type
+ input = permuteIfNWC(input);
//First: Do forward pass to get gate activations and Zs
Quad<INDArray,INDArray, INDArray, INDArray> p = activateHelper(null, true, true, workspaceMgr);
@@ -125,8 +128,9 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
} else {
end = 0;
}
+ epsilon = permuteIfNWC(epsilon);
for( long i = tsLength-1; i>= end; i--){
- INDArray dldaCurrent = epsilon.get(all(), all(), point(i));
+ INDArray dldaCurrent = epsilon.get(all(), all(), point(i)).dup();
INDArray aCurrent = p.getFirst().get(all(), all(), point(i));
INDArray zCurrent = p.getSecond().get(all(), all(), point(i));
INDArray nCurrent = (hasLayerNorm() ? p.getThird().get(all(), all(), point(i)) : null);
@@ -141,7 +145,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
//Recurrent weight gradients:
Nd4j.gemm(aCurrent, dldzNext, rwg, true, false, 1.0, 1.0);
}
- INDArray dldzCurrent = a.backprop(zCurrent.dup(), dldaCurrent.dup()).getFirst();
+ INDArray dldzCurrent = a.backprop(zCurrent.dup(), dldaCurrent).getFirst();
//Handle masking
INDArray maskCol = null;
@@ -200,6 +204,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
}
epsOut = backpropDropOutIfPresent(epsOut);
+ epsOut = permuteIfNWC(epsOut);
return new Pair<>(grad, epsOut);
}
@@ -224,6 +229,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
applyDropOutIfNecessary(training, workspaceMgr);
INDArray input = this.input.castTo(dataType); //No-op if correct type
+ input = permuteIfNWC(input);
val m = input.size(0);
val tsLength = input.size(2);
val nOut = layerConf().getNOut();
@@ -300,7 +306,12 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
Nd4j.getExecutioner().exec(new BroadcastMulOp(outZ, mask, outZ, 0, 2));
}
}
+ if (!forBackprop) {
+ out = permuteIfNWC(out);
+ outZ = permuteIfNWC(outZ);
+ outPreNorm = permuteIfNWC(outPreNorm);
+ recPreNorm = permuteIfNWC(recPreNorm);
+ }
return new Quad<>(out, outZ, outPreNorm, recPreNorm);
}
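Two details in the SimpleRnn hunks are easy to miss. The outputs are permuted back only when !forBackprop, because backpropGradient consumes the NCW versions directly. And the dup() moved onto the epsilon slice: after permuteIfNWC, epsilon may be a strided view, and the slice is handed to an in-place activation backprop, so one defensive copy is taken up front (presumably why the later dldaCurrent.dup() was dropped). The relevant lines, annotated:

    epsilon = permuteIfNWC(epsilon);                                  // possibly a view after permute
    INDArray dldaCurrent = epsilon.get(all(), all(), point(i)).dup(); // one defensive copy per step
    INDArray dldzCurrent = a.backprop(zCurrent.dup(), dldaCurrent).getFirst(); // may write into dldaCurrent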


@@ -2,6 +2,7 @@ package org.deeplearning4j.nn.layers.recurrent;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
+ import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
@@ -22,11 +23,11 @@ import org.nd4j.linalg.util.ArrayUtil;
*/
public class TimeDistributedLayer extends BaseWrapperLayer {
- private final int timeAxis;
+ private RNNFormat rnnDataFormat;
- public TimeDistributedLayer(Layer underlying, int timeAxis) {
+ public TimeDistributedLayer(Layer underlying, RNNFormat rnnDataFormat) {
super(underlying);
- this.timeAxis = timeAxis;
+ this.rnnDataFormat = rnnDataFormat;
}
@@ -56,7 +57,7 @@ public class TimeDistributedLayer extends BaseWrapperLayer {
protected INDArray reshape(INDArray array){
//Reshape the time axis to the minibatch axis
//For example, for RNN -> FF (dense time distributed): [mb, size, seqLen] -> [mb x seqLen, size]
- int axis = timeAxis;
+ int axis = (rnnDataFormat == RNNFormat.NCW) ? 2 : 1;
if(axis < 0)
axis += array.rank();
@@ -91,7 +92,7 @@ public class TimeDistributedLayer extends BaseWrapperLayer {
protected INDArray revertReshape(INDArray toRevert, long minibatch){
- int axis = timeAxis;
+ int axis = (rnnDataFormat == RNNFormat.NCW)? 2 : 1;
if(axis < 0)
axis += (toRevert.rank()+1);


@@ -17,6 +17,13 @@
package org.deeplearning4j.util;
import lombok.val;
+ import org.deeplearning4j.nn.conf.RNNFormat;
+ import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
+ import org.deeplearning4j.nn.conf.layers.Layer;
+ import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
+ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
+ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
+ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
@@ -233,6 +240,12 @@ public class TimeSeriesUtils {
return outReshape.reshape('f', in.size(0), in.size(1), in.size(2));
}
+ public static INDArray reverseTimeSeries(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType, RNNFormat dataFormat){
+ if (dataFormat == RNNFormat.NCW){
+ return reverseTimeSeries(in, workspaceMgr, arrayType);
+ }
+ return reverseTimeSeries(in.permute(0, 2, 1), workspaceMgr, arrayType).permute(0, 2, 1);
+ }
/**
* Reverse an input time series along the time dimension
*
@@ -423,4 +436,25 @@ public class TimeSeriesUtils {
return new Pair<>(workspaceMgr.leverageTo(arrayType, out), fwdPassTimeSteps);
}
+ /**
+ * Get the {@link RNNFormat} from the RNN layer, accounting for the presence of wrapper layers like Bidirectional,
+ * LastTimeStep, etc
+ * @param layer Layer to get the RNNFormat from
+ */
+ public static RNNFormat getFormatFromRnnLayer(Layer layer){
+ if(layer instanceof BaseRecurrentLayer){
+ return ((BaseRecurrentLayer) layer).getRnnDataFormat();
+ } else if(layer instanceof MaskZeroLayer){
+ return getFormatFromRnnLayer(((MaskZeroLayer) layer).getUnderlying());
+ } else if(layer instanceof Bidirectional){
+ return getFormatFromRnnLayer(((Bidirectional) layer).getFwd());
+ } else if(layer instanceof LastTimeStep){
+ return getFormatFromRnnLayer(((LastTimeStep) layer).getUnderlying());
+ } else if(layer instanceof TimeDistributed){
+ return ((TimeDistributed) layer).getRnnDataFormat();
+ } else {
+ throw new IllegalStateException("Unable to get RNNFormat from layer of type: " + layer);
+ }
+ }
}
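A hypothetical caller, to show how the two new pieces compose; net and activations are assumed to exist, and noWorkspaces() is only for the sketch:

    // Resolve the format of layer 0 even when it is wrapped (Bidirectional, LastTimeStep, ...):
    Layer conf0 = net.getLayerWiseConfigurations().getConf(0).getLayer();
    RNNFormat fmt = TimeSeriesUtils.getFormatFromRnnLayer(conf0);
    // Reverse along time; for NWC this permutes to NCW, reverses, and permutes back:
    INDArray reversed = TimeSeriesUtils.reverseTimeSeries(activations,
            LayerWorkspaceMgr.noWorkspaces(), ArrayType.ACTIVATIONS, fmt);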