DL4J NWC support for RNNs (#379)
* merge conf
* merge conf
* conf fix
* NWC initial
* revert pom.xml
* revert pom.xml
* default NCW
* bidirectional+some tests
* RNNOutputLayer, RNNLossLayer, Graves + tests
* rnn tests
* LastTimeStep + tests
* masking + tests
* graves, rnnoutput, rnnloss
* nwc timeseries reverse
* more tests
* bi-gravelstm test
* fixes
* rnn df tests basic
* bug fix: cudnn fallback
* bug fix
* misc
* gravelstm tests
* preprocessor fixes
* TimeDistributed
* more tests
* RnnLossLayer builder def val
* copyright headers
* Remove debug println Signed-off-by: Alex Black <blacka101@gmail.com>
* Small fix + test naming Signed-off-by: Alex Black <blacka101@gmail.com>
* Parameterized test name Signed-off-by: Alex Black <blacka101@gmail.com>
* fix LastTimeStep masked
* Fix MaskZero mask datatype issue Signed-off-by: Alex Black <blacka101@gmail.com>
* rem println
* javadoc
* Fixes Signed-off-by: Alex Black <blacka101@gmail.com>
Co-authored-by: Alex Black <blacka101@gmail.com>
parent 032b97912e
commit 2ecabde500
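For orientation before the diff hunks: a minimal sketch of how an NWC-format recurrent network is configured after this change. It is not part of the commit; the builder calls (dataFormat(RNNFormat.NWC), the format-aware InputType.recurrent overload, RnnOutputLayer.Builder().dataFormat(...)) are taken from the test changes below, while the class name and the concrete sizes are mine.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NwcExample {                 // hypothetical example class, not from the PR
    public static void main(String[] args) {
        // Time series in NWC layout: [miniBatch, timeSteps, channels]
        INDArray input = Nd4j.rand(DataType.FLOAT, 2, 12, 3);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .list()
                // dataFormat(RNNFormat.NWC) tells the layer to expect [batch, time, channels]
                .layer(new LSTM.Builder().nIn(3).nOut(3)
                        .activation(Activation.TANH)
                        .dataFormat(RNNFormat.NWC)
                        .build())
                .layer(new RnnOutputLayer.Builder()
                        .activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .nOut(10)
                        .dataFormat(RNNFormat.NWC)
                        .build())
                // The format can also be declared via the input type (new overload in this PR)
                .setInputType(InputType.recurrent(3, 12, RNNFormat.NWC))
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Output should come back in the same NWC layout, here [2, 12, 10]
        INDArray out = net.output(input);
        System.out.println(java.util.Arrays.toString(out.shape()));
    }
}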
@@ -338,7 +338,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
|
|||
.weightInit(WeightInit.XAVIER)
|
||||
.list()
|
||||
.layer(new LSTM.Builder().nOut(layerSize).build())
|
||||
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build(), 2))
|
||||
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build()))
|
||||
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
.setInputType(InputType.recurrent(nIn))
@@ -819,7 +819,7 @@ public class DTypeTests extends BaseDL4JTest {
|
|||
.layer(new DenseLayer.Builder().nOut(5).build())
|
||||
.layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
||||
.layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()))
|
||||
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2))
|
||||
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
|
||||
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build())
|
||||
.layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
|
||||
.layer(secondLast)
@@ -24,10 +24,7 @@ import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
|||
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
|
||||
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
|
||||
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||
import org.deeplearning4j.nn.conf.*;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
|
||||
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||
|
@@ -45,6 +42,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
|||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.deeplearning4j.util.TimeSeriesUtils;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@@ -61,12 +60,22 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
|
||||
import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class BidirectionalTest extends BaseDL4JTest {
|
||||
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public BidirectionalTest(RNNFormat rnnDataFormat){
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
}
|
||||
@Parameterized.Parameters
|
||||
public static Object[] params(){
|
||||
return RNNFormat.values();
|
||||
}
|
||||
@Test
|
||||
public void compareImplementations(){
|
||||
for(WorkspaceMode wsm : WorkspaceMode.values()) {
|
||||
|
@@ -82,9 +91,9 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
.inferenceWorkspaceMode(wsm)
|
||||
.updater(new Adam())
|
||||
.list()
|
||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
|
||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
|
||||
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
|
||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
|
||||
.nIn(10).nOut(10).build())
|
||||
.build();
|
||||
|
||||
|
@@ -95,9 +104,9 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
.inferenceWorkspaceMode(wsm)
|
||||
.updater(new Adam())
|
||||
.list()
|
||||
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
|
||||
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
|
||||
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
|
||||
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
|
||||
.nIn(10).nOut(10).build())
|
||||
.build();
|
||||
|
||||
|
@@ -116,15 +125,24 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
|
||||
net2.setParams(net1.params()); //Assuming exact same layout here...
|
||||
|
||||
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
|
||||
INDArray in;
|
||||
if (rnnDataFormat == NCW){
|
||||
in = Nd4j.rand(new int[]{3, 10, 5});
|
||||
}else{
|
||||
in = Nd4j.rand(new int[]{3, 5, 10});
|
||||
}
|
||||
|
||||
INDArray out1 = net1.output(in);
|
||||
INDArray out2 = net2.output(in);
|
||||
|
||||
assertEquals(out1, out2);
|
||||
|
||||
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
|
||||
|
||||
INDArray labels;
|
||||
if (rnnDataFormat == NCW){
|
||||
labels = Nd4j.rand(new int[]{3, 10, 5});
|
||||
}else{
|
||||
labels = Nd4j.rand(new int[]{3, 5, 10});
|
||||
}
|
||||
net1.setInput(in);
|
||||
net1.setLabels(labels);
|
||||
|
||||
|
@@ -276,17 +294,22 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
.inferenceWorkspaceMode(wsm)
|
||||
.updater(new Adam())
|
||||
.list()
|
||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
|
||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
|
||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
|
||||
.nIn(10).nOut(10).build())
|
||||
.nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
|
||||
net1.init();
|
||||
|
||||
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
|
||||
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
|
||||
INDArray in;
|
||||
INDArray labels;
|
||||
|
||||
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 5} : new long[]{3, 5, 10};
|
||||
|
||||
in = Nd4j.rand(inshape);
|
||||
labels = Nd4j.rand(inshape);
|
||||
|
||||
net1.fit(in, labels);
|
||||
|
||||
|
@@ -300,8 +323,8 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true);
|
||||
|
||||
|
||||
in = Nd4j.rand(new int[]{3, 10, 5});
|
||||
labels = Nd4j.rand(new int[]{3, 10, 5});
|
||||
in = Nd4j.rand(inshape);
|
||||
labels = Nd4j.rand(inshape);
|
||||
|
||||
INDArray out1 = net1.output(in);
|
||||
INDArray out2 = net2.output(in);
|
||||
|
@@ -338,18 +361,18 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
.updater(new Adam())
|
||||
.graphBuilder()
|
||||
.addInputs("in")
|
||||
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in")
|
||||
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0")
|
||||
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
|
||||
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
|
||||
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0")
|
||||
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
|
||||
.nIn(10).nOut(10).build(), "1")
|
||||
.setOutputs("2")
|
||||
.build();
|
||||
|
||||
ComputationGraph net1 = new ComputationGraph(conf1);
|
||||
net1.init();
|
||||
|
||||
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
|
||||
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
|
||||
long[] inshape = (rnnDataFormat == NCW)? new long[]{3, 10, 5}: new long[]{3, 5, 10};
|
||||
INDArray in = Nd4j.rand(inshape);
|
||||
INDArray labels = Nd4j.rand(inshape);
|
||||
|
||||
net1.fit(new DataSet(in, labels));
|
||||
|
||||
|
@@ -363,8 +386,8 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true);
|
||||
|
||||
|
||||
in = Nd4j.rand(new int[]{3, 10, 5});
|
||||
labels = Nd4j.rand(new int[]{3, 10, 5});
|
||||
in = Nd4j.rand(inshape);
|
||||
labels = Nd4j.rand(inshape);
|
||||
|
||||
INDArray out1 = net1.outputSingle(in);
|
||||
INDArray out2 = net2.outputSingle(in);
|
||||
|
@@ -394,8 +417,8 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD,
|
||||
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
|
||||
|
||||
|
||||
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
|
||||
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
|
||||
INDArray in = Nd4j.rand(inshape);
|
||||
|
||||
for (Bidirectional.Mode m : modes) {
|
||||
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
|
||||
|
@@ -406,7 +429,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
.inferenceWorkspaceMode(wsm)
|
||||
.updater(new Adam())
|
||||
.list()
|
||||
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()))
|
||||
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
|
||||
|
@@ -418,7 +441,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
.weightInit(WeightInit.XAVIER)
|
||||
.updater(new Adam())
|
||||
.list()
|
||||
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build())
|
||||
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone());
|
||||
|
@@ -434,11 +457,10 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
net3.setParam("0_RW", net1.getParam("0_bRW"));
|
||||
net3.setParam("0_b", net1.getParam("0_bb"));
|
||||
|
||||
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
||||
|
||||
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
|
||||
INDArray out1 = net1.output(in);
|
||||
INDArray out2 = net2.output(in);
|
||||
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
||||
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
|
||||
|
||||
INDArray outExp;
|
||||
switch (m) {
|
||||
|
@@ -452,7 +474,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
outExp = out2.add(out3).muli(0.5);
|
||||
break;
|
||||
case CONCAT:
|
||||
outExp = Nd4j.concat(1, out2, out3);
|
||||
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
|
@@ -464,25 +486,25 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
//Check gradients:
|
||||
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
|
||||
|
||||
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
|
||||
INDArray eps = Nd4j.rand(inshape);
|
||||
|
||||
INDArray eps1;
|
||||
if (m == Bidirectional.Mode.CONCAT) {
|
||||
eps1 = Nd4j.concat(1, eps, eps);
|
||||
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
|
||||
} else {
|
||||
eps1 = eps;
|
||||
}
|
||||
|
||||
net1.setInput(in);
|
||||
net2.setInput(in);
|
||||
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
|
||||
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat));
|
||||
net1.feedForward(true, false);
|
||||
net2.feedForward(true, false);
|
||||
net3.feedForward(true, false);
|
||||
|
||||
Pair<Gradient, INDArray> p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces());
|
||||
Pair<Gradient, INDArray> p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
||||
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT), LayerWorkspaceMgr.noWorkspaces());
|
||||
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat), LayerWorkspaceMgr.noWorkspaces());
|
||||
Gradient g1 = p1.getFirst();
|
||||
Gradient g2 = p2.getFirst();
|
||||
Gradient g3 = p3.getFirst();
|
||||
|
@@ -520,7 +542,9 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
|
||||
|
||||
|
||||
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
|
||||
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
|
||||
INDArray in = Nd4j.rand(inshape);
|
||||
|
||||
|
||||
for (Bidirectional.Mode m : modes) {
|
||||
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
|
||||
|
@@ -532,7 +556,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
.updater(new Adam())
|
||||
.graphBuilder()
|
||||
.addInputs("in")
|
||||
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()), "in")
|
||||
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
|
||||
.setOutputs("0")
|
||||
.build();
|
||||
|
||||
|
@@ -546,7 +570,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
.updater(new Adam())
|
||||
.graphBuilder()
|
||||
.addInputs("in")
|
||||
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in")
|
||||
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in")
|
||||
.setOutputs("0")
|
||||
.build();
|
||||
|
||||
|
@@ -566,9 +590,20 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
|
||||
INDArray out1 = net1.outputSingle(in);
|
||||
INDArray out2 = net2.outputSingle(in);
|
||||
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.outputSingle(
|
||||
TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)),
|
||||
LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
||||
INDArray out3;
|
||||
INDArray inReverse;
|
||||
if (rnnDataFormat == RNNFormat.NWC){
|
||||
inReverse = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
|
||||
out3 = net3.outputSingle(inReverse);
|
||||
out3 = TimeSeriesUtils.reverseTimeSeries(out3.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
|
||||
|
||||
}
|
||||
else{
|
||||
inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
||||
out3 = net3.outputSingle(inReverse);
|
||||
out3 = TimeSeriesUtils.reverseTimeSeries(out3, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
||||
|
||||
}
|
||||
|
||||
INDArray outExp;
|
||||
switch (m) {
|
||||
|
@@ -582,7 +617,9 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
outExp = out2.add(out3).muli(0.5);
|
||||
break;
|
||||
case CONCAT:
|
||||
outExp = Nd4j.concat(1, out2, out3);
|
||||
System.out.println(out2.shapeInfoToString());
|
||||
System.out.println(out3.shapeInfoToString());
|
||||
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
|
@@ -594,22 +631,26 @@ public class BidirectionalTest extends BaseDL4JTest {
|
|||
//Check gradients:
|
||||
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
|
||||
|
||||
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
|
||||
INDArray eps = Nd4j.rand(inshape);
|
||||
|
||||
INDArray eps1;
|
||||
if (m == Bidirectional.Mode.CONCAT) {
|
||||
eps1 = Nd4j.concat(1, eps, eps);
|
||||
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
|
||||
} else {
|
||||
eps1 = eps;
|
||||
}
|
||||
|
||||
INDArray epsReversed = (rnnDataFormat == NCW)?
|
||||
TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT):
|
||||
TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)
|
||||
.permute(0, 2, 1);
|
||||
net1.outputSingle(true, false, in);
|
||||
net2.outputSingle(true, false, in);
|
||||
net3.outputSingle(true, false, TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
|
||||
net3.outputSingle(true, false, inReverse);
|
||||
|
||||
Gradient g1 = net1.backpropGradient(eps1);
|
||||
Gradient g2 = net2.backpropGradient(eps);
|
||||
Gradient g3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
|
||||
Gradient g3 = net3.backpropGradient(epsReversed);
|
||||
|
||||
for (boolean updates : new boolean[]{false, true}) {
|
||||
if (updates) {
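A side note on the BidirectionalTest changes above: the tests rely on two equivalent routes for reversing a time series under NWC, the new format-aware TimeSeriesUtils.reverseTimeSeries overload introduced by this PR and a manual permute-to-NCW round trip. A minimal sketch of that equivalence follows; package paths are assumed from the imports in the tests, the class name and sizes are mine, and this is illustrative, not part of the commit.

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ReverseTimeSeriesSketch {    // hypothetical example class, not from the PR
    public static void main(String[] args) {
        // NWC layout: [miniBatch, timeSteps, channels]
        INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 10);

        // Format-aware overload added by this PR: reverses along the time axis implied by the format
        INDArray r1 = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(),
                ArrayType.INPUT, RNNFormat.NWC);

        // Manual route used by some tests: permute to NCW, reverse, permute back to NWC
        INDArray r2 = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1),
                LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);

        System.out.println(r1.equals(r2)); // expected: true
    }
}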
@@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
|||
import org.deeplearning4j.nn.conf.CacheMode;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
|
@@ -31,6 +32,8 @@ import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
|
|||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@@ -42,10 +45,18 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||
private double score = 0.0;
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
}
|
||||
@Parameterized.Parameters
|
||||
public static Object[] params(){
|
||||
return RNNFormat.values();
|
||||
}
|
||||
@Test
|
||||
public void testBidirectionalLSTMGravesForwardBasic() {
|
||||
//Very basic test of forward prop. of LSTM layer with a time series.
|
||||
|
@@ -55,7 +66,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
|
||||
final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
|
||||
.nOut(nHiddenUnits).activation(Activation.TANH).build())
|
||||
.nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build())
|
||||
.build();
|
||||
|
||||
val numParams = conf.getLayer().initializer().numParams(conf);
|
||||
|
@@ -65,22 +76,41 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
|
||||
//Data: has shape [miniBatchSize,nIn,timeSeriesLength];
|
||||
//Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength];
|
||||
if (rnnDataFormat == RNNFormat.NCW){
|
||||
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1);
|
||||
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1});
|
||||
|
||||
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1);
|
||||
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1});
|
||||
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1);
|
||||
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1});
|
||||
|
||||
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1);
|
||||
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1});
|
||||
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12);
|
||||
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12});
|
||||
|
||||
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12);
|
||||
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12});
|
||||
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15);
|
||||
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
|
||||
}
|
||||
else{
|
||||
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, 1, nIn);
|
||||
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations1.shape(), new long[] {1, 1, nHiddenUnits});
|
||||
|
||||
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, 1, nIn);
|
||||
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations2.shape(), new long[] {10, 1, nHiddenUnits});
|
||||
|
||||
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, 12, nIn);
|
||||
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations3.shape(), new long[] {1, 12, nHiddenUnits});
|
||||
|
||||
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, 15, nIn);
|
||||
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations4.shape(), new long[] {10, 15, nHiddenUnits});
|
||||
}
|
||||
|
||||
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15);
|
||||
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@@ -94,14 +124,15 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
|
||||
}
|
||||
|
||||
private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
|
||||
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
|
||||
int timeSeriesLength) {
|
||||
|
||||
INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength);
|
||||
INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength):
|
||||
Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
|
||||
|
||||
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
|
||||
.nOut(lstmNHiddenUnits)
|
||||
.nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat)
|
||||
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
|
||||
.build();
|
||||
|
||||
|
@@ -114,7 +145,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertNotNull(lstm.input());
|
||||
|
||||
INDArray epsilon = Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength);
|
||||
INDArray epsilon =(rnnDataFormat == RNNFormat.NCW)? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength):
|
||||
Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits);
|
||||
|
||||
Pair<Gradient, INDArray> out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
|
||||
Gradient outGradient = out.getFirst();
|
||||
|
@@ -147,7 +179,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3});
|
||||
|
||||
assertNotNull(nextEpsilon);
|
||||
assertArrayEquals(nextEpsilon.shape(), new long[] {miniBatchSize, nIn, timeSeriesLength});
|
||||
if (rnnDataFormat == RNNFormat.NCW) {
|
||||
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, nIn, timeSeriesLength});
|
||||
}else{
|
||||
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, timeSeriesLength, nIn });
|
||||
}
|
||||
|
||||
//Check update:
|
||||
for (String s : outGradient.gradientForVariable().keySet()) {
|
||||
|
@@ -226,7 +262,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
|
||||
final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder()
|
||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
|
||||
.nOut(layerSize)
|
||||
.nOut(layerSize).dataFormat(rnnDataFormat)
|
||||
.dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build())
|
||||
.build();
|
||||
|
||||
|
@@ -237,7 +273,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
.instantiate(confBidirectional, null, 0, params, true, params.dataType());
|
||||
|
||||
|
||||
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
|
||||
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
|
||||
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
|
||||
|
||||
final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
||||
|
@@ -265,13 +302,13 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
final NeuralNetConfiguration confBidirectional =
|
||||
new NeuralNetConfiguration.Builder()
|
||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
|
||||
.nIn(nIn).nOut(layerSize)
|
||||
.nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
|
||||
.dist(new UniformDistribution(-0.1, 0.1))
|
||||
.activation(Activation.TANH).updater(new NoOp()).build())
|
||||
.build();
|
||||
|
||||
final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder()
|
||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
|
||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
|
||||
.weightInit(WeightInit.ZERO).activation(Activation.TANH).build())
|
||||
.build();
|
||||
|
||||
|
@@ -290,9 +327,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards)));
|
||||
|
||||
|
||||
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
|
||||
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
|
||||
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
|
||||
final INDArray sigb = sig.dup();
|
||||
reverseColumnsInPlace(sigb.slice(0));
|
||||
|
||||
if (rnnDataFormat == RNNFormat.NCW) {
|
||||
reverseColumnsInPlace(sigb.slice(0));
|
||||
}
|
||||
else{
|
||||
reverseColumnsInPlace(sigb.slice(0).permute(1, 0));
|
||||
}
|
||||
|
||||
final INDArray recurrentWeightsF = bidirectionalLSTM
|
||||
.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS);
|
||||
|
@@ -345,10 +389,14 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
|
||||
assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f);
|
||||
|
||||
final INDArray randSig = Nd4j.rand(new int[] {1, layerSize, timeSeriesLength});
|
||||
final INDArray randSigBackwards = randSig.dup();
|
||||
reverseColumnsInPlace(randSigBackwards.slice(0));
|
||||
|
||||
final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}):
|
||||
Nd4j.rand(new int[] {1, timeSeriesLength, layerSize});
|
||||
INDArray randSigBackwards = randSig.dup();
|
||||
if (rnnDataFormat == RNNFormat.NCW){
|
||||
reverseColumnsInPlace(randSigBackwards.slice(0));
|
||||
}else{
|
||||
reverseColumnsInPlace(randSigBackwards.slice(0).permute(1, 0));
|
||||
}
|
||||
|
||||
final Pair<Gradient, INDArray> backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
|
||||
final Pair<Gradient, INDArray> backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
@@ -399,10 +447,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0);
|
||||
|
||||
final INDArray activation3Reverse = activation3.dup();
|
||||
reverseColumnsInPlace(activation3Reverse);
|
||||
if (rnnDataFormat == RNNFormat.NCW){
|
||||
reverseColumnsInPlace(activation3Reverse);
|
||||
}
|
||||
else{
|
||||
reverseColumnsInPlace(activation3Reverse.permute(1, 0));
|
||||
}
|
||||
|
||||
assertEquals(activation3Reverse, activation1);
|
||||
assertArrayEquals(activation3Reverse.shape(), activation1.shape());
|
||||
assertEquals(activation3Reverse, activation1);
|
||||
|
||||
|
||||
//test backprop now
|
||||
final INDArray refBackGradientReccurrent =
|
||||
|
@@ -434,7 +488,12 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
final INDArray refEpsilon = backprop1.getSecond().dup();
|
||||
final INDArray backEpsilon = backprop3.getSecond().dup();
|
||||
|
||||
reverseColumnsInPlace(refEpsilon.slice(0));
|
||||
if (rnnDataFormat == RNNFormat.NCW) {
|
||||
reverseColumnsInPlace(refEpsilon.slice(0));
|
||||
}
|
||||
else{
|
||||
reverseColumnsInPlace(refEpsilon.slice(0).permute(1, 0));
|
||||
}
|
||||
assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6);
|
||||
|
||||
}
|
||||
|
@@ -477,10 +536,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||
.seed(12345).list()
|
||||
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
|
||||
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2)
|
||||
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat)
|
||||
.build())
|
||||
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder()
|
||||
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2)
|
||||
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat)
|
||||
.activation(Activation.TANH).build())
|
||||
.build();
|
||||
|
||||
|
@@ -492,7 +551,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
|
||||
INDArray in = Nd4j.rand(new int[] {3, 2, 5});
|
||||
INDArray labels = Nd4j.rand(new int[] {3, 2, 5});
|
||||
|
||||
if (rnnDataFormat == RNNFormat.NWC){
|
||||
in = in.permute(0, 2, 1);
|
||||
labels = labels.permute(0, 2, 1);
|
||||
}
|
||||
net.fit(in, labels);
|
||||
}
|
||||
}
@@ -21,11 +21,14 @@ import org.deeplearning4j.TestUtils;
|
|||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@@ -36,9 +39,17 @@ import java.util.Collections;
|
|||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class MaskZeroLayerTest extends BaseDL4JTest {
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public MaskZeroLayerTest(RNNFormat rnnDataFormat){
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
}
|
||||
@Parameterized.Parameters
|
||||
public static Object[] params(){
|
||||
return RNNFormat.values();
|
||||
}
|
||||
@Test
|
||||
public void activate() {
|
||||
|
||||
|
@@ -57,7 +68,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
|
|||
.activation(Activation.IDENTITY)
|
||||
.gateActivationFunction(Activation.IDENTITY)
|
||||
.nIn(2)
|
||||
.nOut(1)
|
||||
.nOut(1).dataFormat(rnnDataFormat)
|
||||
.build();
|
||||
NeuralNetConfiguration conf = new NeuralNetConfiguration();
|
||||
conf.setLayer(underlying);
|
||||
|
@@ -72,9 +83,14 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
|
|||
|
||||
MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue);
|
||||
INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3});
|
||||
if (rnnDataFormat == RNNFormat.NWC){
|
||||
input = input.permute(0, 2, 1);
|
||||
}
|
||||
//WHEN
|
||||
INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
||||
if (rnnDataFormat == RNNFormat.NWC){
|
||||
out = out.permute(0, 2,1);
|
||||
}
|
||||
//THEN output should only be incremented for the non-zero timesteps
|
||||
INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
|
||||
INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
|
||||
|
@@ -94,7 +110,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
|
|||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.list()
|
||||
.layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder()
|
||||
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).build()).build())
|
||||
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build())
|
||||
.build();
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
@@ -0,0 +1,394 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.nn.layers.recurrent;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.TestUtils;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
|
||||
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
@AllArgsConstructor
|
||||
public class RnnDataFormatTests extends BaseDL4JTest {
|
||||
|
||||
private boolean helpers;
|
||||
private boolean lastTimeStep;
|
||||
private boolean maskZeros;
|
||||
|
||||
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}")
|
||||
public static List params(){
|
||||
List<Object[]> ret = new ArrayList<>();
|
||||
for (boolean helpers: new boolean[]{true, false})
|
||||
for (boolean lastTimeStep: new boolean[]{true, false})
|
||||
for (boolean maskZero: new boolean[]{true, false})
|
||||
ret.add(new Object[]{helpers, lastTimeStep, maskZero});
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testSimpleRnn() {
|
||||
try {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
|
||||
|
||||
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getSimpleRnnNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
|
||||
.net2(getSimpleRnnNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
|
||||
.net3(getSimpleRnnNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
|
||||
.net4(getSimpleRnnNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
|
||||
.inNCW(inNCW)
|
||||
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
|
||||
.labelsNWC(labelsNWC)
|
||||
.testLayerIdx(1)
|
||||
.build();
|
||||
|
||||
TestCase.testHelper(tc);
|
||||
|
||||
|
||||
} finally {
|
||||
Nd4j.getEnvironment().allowHelpers(true);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLSTM() {
|
||||
try {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
|
||||
|
||||
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
|
||||
.net2(getLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
|
||||
.net3(getLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
|
||||
.net4(getLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
|
||||
.inNCW(inNCW)
|
||||
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
|
||||
.labelsNWC(labelsNWC)
|
||||
.testLayerIdx(1)
|
||||
.build();
|
||||
|
||||
TestCase.testHelper(tc);
|
||||
|
||||
|
||||
} finally {
|
||||
Nd4j.getEnvironment().allowHelpers(true);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testGraveLSTM() {
|
||||
try {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
|
||||
|
||||
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getGravesLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
|
||||
.net2(getGravesLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
|
||||
.net3(getGravesLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
|
||||
.net4(getGravesLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
|
||||
.inNCW(inNCW)
|
||||
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
|
||||
.labelsNWC(labelsNWC)
|
||||
.testLayerIdx(1)
|
||||
.build();
|
||||
|
||||
TestCase.testHelper(tc);
|
||||
|
||||
|
||||
} finally {
|
||||
Nd4j.getEnvironment().allowHelpers(true);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testGraveBiLSTM() {
|
||||
try {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
|
||||
|
||||
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getGravesBidirectionalLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
|
||||
.net2(getGravesBidirectionalLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
|
||||
.net3(getGravesBidirectionalLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
|
||||
.net4(getGravesBidirectionalLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
|
||||
.inNCW(inNCW)
|
||||
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
|
||||
.labelsNWC(labelsNWC)
|
||||
.testLayerIdx(1)
|
||||
.build();
|
||||
|
||||
TestCase.testHelper(tc);
|
||||
|
||||
|
||||
} finally {
|
||||
Nd4j.getEnvironment().allowHelpers(true);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private MultiLayerNetwork getGravesBidirectionalLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3)
|
||||
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||
} else {
|
||||
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||
}
|
||||
}
|
||||
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new GravesLSTM.Builder().nOut(3)
|
||||
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||
} else {
|
||||
return getNetWithLayer(new GravesLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new LSTM.Builder().nOut(3)
|
||||
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||
} else {
|
||||
return getNetWithLayer(new LSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getSimpleRnnNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new SimpleRnn.Builder().nOut(3)
|
||||
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||
} else {
|
||||
return getNetWithLayer(new SimpleRnn.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||
}
|
||||
}
|
||||
private MultiLayerNetwork getNetWithLayer(Layer layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
|
||||
if (maskZeros){
|
||||
layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build();
|
||||
}
|
||||
if(lastTimeStep){
|
||||
layer = new LastTimeStep(layer);
|
||||
}
|
||||
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
|
||||
.seed(12345)
|
||||
.list()
|
||||
.layer(new LSTM.Builder()
|
||||
.nIn(3)
|
||||
.activation(Activation.TANH)
|
||||
.dataFormat(format)
|
||||
.nOut(3)
|
||||
.helperAllowFallback(false)
|
||||
.build())
|
||||
.layer(layer)
|
||||
.layer(
|
||||
(lastTimeStep)?new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build():
|
||||
new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).dataFormat(format).build()
|
||||
)
|
||||
.setInputType(InputType.recurrent(3, 12, format));
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
|
||||
net.init();
|
||||
return net;
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@Builder
|
||||
private static class TestCase {
|
||||
private String msg;
|
||||
private MultiLayerNetwork net1;
|
||||
private MultiLayerNetwork net2;
|
||||
private MultiLayerNetwork net3;
|
||||
private MultiLayerNetwork net4;
|
||||
private INDArray inNCW;
|
||||
private INDArray labelsNCW;
|
||||
private INDArray labelsNWC;
|
||||
private int testLayerIdx;
|
||||
private boolean nwcOutput;
|
||||
|
||||
public static void testHelper(TestCase tc) {
|
||||
|
||||
tc.net2.params().assign(tc.net1.params());
|
||||
tc.net3.params().assign(tc.net1.params());
|
||||
tc.net4.params().assign(tc.net1.params());
|
||||
|
||||
INDArray inNCW = tc.inNCW;
|
||||
INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup();
|
||||
|
||||
INDArray l0_1 = tc.net1.feedForward(inNCW).get(tc.testLayerIdx + 1);
|
||||
INDArray l0_2 = tc.net2.feedForward(inNCW).get(tc.testLayerIdx + 1);
|
||||
INDArray l0_3 = tc.net3.feedForward(inNWC).get(tc.testLayerIdx + 1);
|
||||
INDArray l0_4 = tc.net4.feedForward(inNWC).get(tc.testLayerIdx + 1);
|
||||
|
||||
boolean rank3Out = tc.labelsNCW.rank() == 3;
|
||||
assertEquals(tc.msg, l0_1, l0_2);
|
||||
if (rank3Out){
|
||||
assertEquals(tc.msg, l0_1, l0_3.permute(0, 2, 1));
|
||||
assertEquals(tc.msg, l0_1, l0_4.permute(0, 2, 1));
|
||||
}
|
||||
else{
|
||||
assertEquals(tc.msg, l0_1, l0_3);
|
||||
assertEquals(tc.msg, l0_1, l0_4);
|
||||
}
|
||||
INDArray out1 = tc.net1.output(inNCW);
|
||||
INDArray out2 = tc.net2.output(inNCW);
|
||||
INDArray out3 = tc.net3.output(inNWC);
|
||||
INDArray out4 = tc.net4.output(inNWC);
|
||||
|
||||
assertEquals(tc.msg, out1, out2);
|
||||
if (rank3Out){
|
||||
assertEquals(tc.msg, out1, out3.permute(0, 2, 1)); //NWC to NCW
|
||||
assertEquals(tc.msg, out1, out4.permute(0, 2, 1));
|
||||
}
|
||||
else{
|
||||
assertEquals(tc.msg, out1, out3); //NWC to NCW
|
||||
assertEquals(tc.msg, out1, out4);
|
||||
}
|
||||
|
||||
|
||||
//Test backprop
|
||||
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCW, tc.labelsNCW, null, null);
|
||||
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCW, tc.labelsNCW, null, null);
|
||||
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNWC, tc.labelsNWC, null, null);
|
||||
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNWC, tc.labelsNWC, null, null);
|
||||
|
||||
//Input gradients
|
||||
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
|
||||
|
||||
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0, 2, 1)); //Input gradients for NWC input are also in NWC format
|
||||
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0, 2, 1));
|
||||
|
||||
|
||||
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
|
||||
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
|
||||
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
|
||||
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
|
||||
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
|
||||
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
|
||||
|
||||
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
|
||||
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
|
||||
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
|
||||
|
||||
tc.net1.fit(inNCW, tc.labelsNCW);
|
||||
tc.net2.fit(inNCW, tc.labelsNCW);
|
||||
tc.net3.fit(inNWC, tc.labelsNWC);
|
||||
tc.net4.fit(inNWC, tc.labelsNWC);
|
||||
|
||||
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
|
||||
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
|
||||
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
|
||||
|
||||
//Test serialization
|
||||
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
|
||||
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
|
||||
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
|
||||
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
|
||||
|
||||
out1 = tc.net1.output(inNCW);
|
||||
assertEquals(tc.msg, out1, net1a.output(inNCW));
|
||||
assertEquals(tc.msg, out1, net2a.output(inNCW));
|
||||
|
||||
if (rank3Out) {
|
||||
assertEquals(tc.msg, out1, net3a.output(inNWC).permute(0, 2, 1)); //NWC to NCW
|
||||
assertEquals(tc.msg, out1, net4a.output(inNWC).permute(0, 2, 1));
|
||||
}
|
||||
else{
|
||||
assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW
|
||||
assertEquals(tc.msg, out1, net4a.output(inNWC));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
private static List<String> differentGrads(Gradient g1, Gradient g2){
|
||||
List<String> differs = new ArrayList<>();
|
||||
Map<String,INDArray> m1 = g1.gradientForVariable();
|
||||
Map<String,INDArray> m2 = g2.gradientForVariable();
|
||||
for(String s : m1.keySet()){
|
||||
INDArray a1 = m1.get(s);
|
||||
INDArray a2 = m2.get(s);
|
||||
if(!a1.equals(a2)){
|
||||
differs.add(s);
|
||||
}
|
||||
}
|
||||
return differs;
|
||||
}
|
||||
}
@@ -21,6 +21,7 @@ import org.deeplearning4j.TestUtils;
|
|||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||
|
@@ -29,6 +30,8 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
|||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
@@ -42,14 +45,25 @@ import static org.nd4j.linalg.activations.Activation.IDENTITY;
|
|||
import static org.nd4j.linalg.activations.Activation.TANH;
|
||||
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
|
||||
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestLastTimeStepLayer extends BaseDL4JTest {
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
}
|
||||
@Parameterized.Parameters(name="{0}")
|
||||
public static Object[] params(){
|
||||
return RNNFormat.values();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLastTimeStepVertex() {
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
|
||||
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
|
||||
.nIn(5).nOut(6).build()), "in")
|
||||
.nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in")
|
||||
.setOutputs("lastTS")
|
||||
.build();
|
||||
|
||||
|
@@ -59,9 +73,22 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
|||
//First: test without input mask array
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
Layer l = graph.getLayer("lastTS");
|
||||
INDArray in = Nd4j.rand(new int[]{3, 5, 6});
|
||||
INDArray in;
|
||||
if (rnnDataFormat == RNNFormat.NCW){
|
||||
in = Nd4j.rand(3, 5, 6);
|
||||
}
|
||||
else{
|
||||
in = Nd4j.rand(3, 6, 5);
|
||||
}
|
||||
INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
INDArray expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
|
||||
INDArray expOut;
|
||||
if (rnnDataFormat == RNNFormat.NCW){
|
||||
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
|
||||
}
|
||||
else{
|
||||
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.point(5), NDArrayIndex.all());
|
||||
}
|
||||
|
||||
|
||||
|
||||
//Forward pass:
|
||||
|
@@ -76,9 +103,17 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
|||
graph.setLayerMaskArrays(new INDArray[]{inMask}, null);
|
||||
|
||||
expOut = Nd4j.zeros(3, 6);
|
||||
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
|
||||
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
|
||||
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));
|
||||
if (rnnDataFormat == RNNFormat.NCW){
|
||||
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
|
||||
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
|
||||
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));
|
||||
}
|
||||
else{
|
||||
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all()));
|
||||
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.point(3), NDArrayIndex.all()));
|
||||
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.point(4), NDArrayIndex.all()));
|
||||
}
|
||||
|
||||
|
||||
outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertEquals(expOut, outFwd);
|
||||
|
@ -97,9 +132,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
|||
.seed(1234)
|
||||
.graphBuilder()
|
||||
.addInputs("in")
|
||||
.setInputTypes(InputType.recurrent(1))
|
||||
.setInputTypes(InputType.recurrent(1, rnnDataFormat))
|
||||
.addLayer("RNN", new LastTimeStep(new LSTM.Builder()
|
||||
.nOut(10)
|
||||
.nOut(10).dataFormat(rnnDataFormat)
|
||||
.build()), "in")
|
||||
.addLayer("dense", new DenseLayer.Builder()
|
||||
.nOut(10)
|
||||
|
@ -120,7 +155,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
|||
INDArray fm2 = Nd4j.zeros(1,24);
|
||||
INDArray fm3 = Nd4j.zeros(1,24);
|
||||
fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1);
|
||||
|
||||
if (rnnDataFormat == RNNFormat.NWC){
|
||||
f = f.permute(0, 2, 1);
|
||||
}
|
||||
INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1});
|
||||
try {
|
||||
cg.output(false, new INDArray[]{f}, new INDArray[]{fm2});
|
||||
|
|
|
@@ -20,6 +20,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.dropout.TestDropout;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;

@@ -31,6 +32,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -46,8 +49,18 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;

@RunWith(Parameterized.class)
public class TestRnnLayers extends BaseDL4JTest {

    private RNNFormat rnnDataFormat;

    public TestRnnLayers(RNNFormat rnnDataFormat){
        this.rnnDataFormat = rnnDataFormat;
    }
    @Parameterized.Parameters
    public static Object[] params(){
        return RNNFormat.values();
    }
    @Test
    public void testTimeStepIs3Dimensional() {

@@ -58,8 +71,8 @@ public class TestRnnLayers extends BaseDL4JTest {
                .updater(new NoOp())
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).build())
                .layer(new LSTM.Builder().nIn(3).nOut(5).build())
                .layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).dataFormat(rnnDataFormat).build())
                .layer(new LSTM.Builder().nIn(3).nOut(5).dataFormat(rnnDataFormat).build())
                .layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build())
                .build();

@@ -70,9 +83,9 @@ public class TestRnnLayers extends BaseDL4JTest {
        org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn =
                (org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0);

        INDArray rnnInput3d = Nd4j.create(10, 12, 1);
        INDArray rnnInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10,12, 1):Nd4j.create(10, 1, 12);
        INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces());
        assertTrue(Arrays.equals(simpleOut.shape(), new long[] {10, 3, 1}));
        assertTrue(Arrays.equals(simpleOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 3, 1}:new long[]{10, 1, 3}));

        INDArray rnnInput2d = Nd4j.create(10, 12);
        try {

@@ -84,9 +97,9 @@ public class TestRnnLayers extends BaseDL4JTest {
        org.deeplearning4j.nn.layers.recurrent.LSTM lstm =
                (org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1);

        INDArray lstmInput3d = Nd4j.create(10, 3, 1);
        INDArray lstmInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10, 3, 1):Nd4j.create(10, 1, 3);
        INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces());
        assertTrue(Arrays.equals(lstmOut.shape(), new long[] {10, 5, 1}));
        assertTrue(Arrays.equals(lstmOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 5, 1}:new long[]{10, 1, 5}));

        INDArray lstmInput2d = Nd4j.create(10, 3);
        try {

@@ -112,19 +125,19 @@ public class TestRnnLayers extends BaseDL4JTest {
            TestDropout.CustomDropout cd = new TestDropout.CustomDropout();
            switch (s){
                case "graves":
                    layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
                    layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
                    layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
                    layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
                    layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
                    layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
                    break;
                case "lstm":
                    layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
                    layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
                    layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
                    layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
                    layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
                    layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
                    break;
                case "simple":
                    layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
                    layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
                    layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
                    layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
                    layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
                    layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
                    break;
                default:
                    throw new RuntimeException(s);

@@ -134,21 +147,21 @@ public class TestRnnLayers extends BaseDL4JTest {
                    .seed(12345)
                    .list()
                    .layer(layer)
                    .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
                    .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
                    .build();

            MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .list()
                    .layer(layerD)
                    .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
                    .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
                    .build();

            MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .list()
                    .layer(layerD2)
                    .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
                    .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
                    .build();

            MultiLayerNetwork net = new MultiLayerNetwork(conf);

@@ -178,7 +191,6 @@ public class TestRnnLayers extends BaseDL4JTest {
            assertNotEquals(s, out2, out2D);

            INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);

            net.fit(f.dup(), l);
            netD.fit(f.dup(), l);
            assertNotEquals(s, net.params(), netD.params());

@@ -209,10 +221,10 @@ public class TestRnnLayers extends BaseDL4JTest {

            switch (i){
                case 0:
                    lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build());
                    lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
                    break;
                case 1:
                    lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build());
                    lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).dataFormat(rnnDataFormat).build());
                    break;
                default:
                    throw new RuntimeException();

@@ -224,13 +236,16 @@ public class TestRnnLayers extends BaseDL4JTest {

            INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
            INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);

            if (rnnDataFormat == RNNFormat.NWC){
                l = l.permute(0, 2, 1);
            }
            try{
                net.fit(in,l);
            } catch (Throwable t){
                String msg = t.getMessage();
                if(msg == null)
                    t.printStackTrace();
                System.out.println(i);
                assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
            }
@@ -20,10 +20,13 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -36,8 +39,18 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;

@RunWith(Parameterized.class)
public class TestSimpleRnn extends BaseDL4JTest {

    private RNNFormat rnnDataFormat;

    public TestSimpleRnn(RNNFormat rnnDataFormat){
        this.rnnDataFormat = rnnDataFormat;
    }
    @Parameterized.Parameters
    public static Object[] params(){
        return RNNFormat.values();
    }
    @Test
    public void testSimpleRnn(){
        Nd4j.getRandom().setSeed(12345);

@@ -46,7 +59,15 @@ public class TestSimpleRnn extends BaseDL4JTest {
        int nIn = 5;
        int layerSize = 6;
        int tsLength = 7;
        INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
        INDArray in;
        if (rnnDataFormat == RNNFormat.NCW){
            in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
        }
        else{
            in = Nd4j.rand(DataType.FLOAT, new int[]{m, tsLength, nIn});
        }

        // in.get(all(), all(), interval(1,tsLength)).assign(0);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()

@@ -54,7 +75,7 @@ public class TestSimpleRnn extends BaseDL4JTest {
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.TANH)
                .list()
                .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).build())
                .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);

@@ -68,7 +89,13 @@ public class TestSimpleRnn extends BaseDL4JTest {

        INDArray outLast = null;
        for( int i=0; i<tsLength; i++ ){
            INDArray inCurrent = in.get(all(), all(), point(i));
            INDArray inCurrent;
            if (rnnDataFormat == RNNFormat.NCW){
                inCurrent = in.get(all(), all(), point(i));
            }
            else{
                inCurrent = in.get(all(), point(i), all());
            }

            INDArray outExpCurrent = inCurrent.mmul(w);
            if(outLast != null){

@@ -79,7 +106,13 @@ public class TestSimpleRnn extends BaseDL4JTest {

            Transforms.tanh(outExpCurrent, false);

            INDArray outActCurrent = out.get(all(), all(), point(i));
            INDArray outActCurrent;
            if (rnnDataFormat == RNNFormat.NCW){
                outActCurrent = out.get(all(), all(), point(i));
            }
            else{
                outActCurrent = out.get(all(), point(i), all());
            }
            assertEquals(String.valueOf(i), outExpCurrent, outActCurrent);

            outLast = outExpCurrent;

@@ -100,7 +133,7 @@ public class TestSimpleRnn extends BaseDL4JTest {
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.TANH)
                .list()
                .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize)
                .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
                        .biasInit(100)
                        .build())
                .build();
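Note (not part of the diff): the test above only changes which axis is sliced per time step. A minimal standalone sketch of the same idea, assuming nothing beyond ND4J indexing; the array sizes are arbitrary:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import static org.nd4j.linalg.indexing.NDArrayIndex.all;
    import static org.nd4j.linalg.indexing.NDArrayIndex.point;

    public class RnnFormatSlicingSketch {
        public static void main(String[] args) {
            int miniBatch = 2, nIn = 5, tsLength = 7;

            // NCW: [miniBatch, nIn, tsLength] - time is the last axis
            INDArray ncw = Nd4j.rand(miniBatch, nIn, tsLength);
            INDArray stepNcw = ncw.get(all(), all(), point(3));   // shape [miniBatch, nIn]

            // NWC: [miniBatch, tsLength, nIn] - time is the middle axis
            INDArray nwc = Nd4j.rand(miniBatch, tsLength, nIn);
            INDArray stepNwc = nwc.get(all(), point(3), all());   // shape [miniBatch, nIn]

            System.out.println(java.util.Arrays.toString(stepNcw.shape()));
            System.out.println(java.util.Arrays.toString(stepNwc.shape()));
        }
    }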
@@ -4,6 +4,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;

@@ -12,6 +13,8 @@ import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -22,8 +25,18 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;

import static org.junit.Assert.assertEquals;

@RunWith(Parameterized.class)
public class TestTimeDistributed extends BaseDL4JTest {

    private RNNFormat rnnDataFormat;

    public TestTimeDistributed(RNNFormat rnnDataFormat){
        this.rnnDataFormat = rnnDataFormat;
    }
    @Parameterized.Parameters
    public static Object[] params(){
        return RNNFormat.values();
    }
    @Test
    public void testTimeDistributed(){
        for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {

@@ -34,11 +47,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
                    .seed(12345)
                    .updater(new Adam(0.1))
                    .list()
                    .layer(new LSTM.Builder().nIn(3).nOut(3).build())
                    .layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
                    .layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build())
                    .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
                    .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .setInputType(InputType.recurrent(3))
                    .setInputType(InputType.recurrent(3, rnnDataFormat))
                    .build();

            MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()

@@ -47,11 +60,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
                    .seed(12345)
                    .updater(new Adam(0.1))
                    .list()
                    .layer(new LSTM.Builder().nIn(3).nOut(3).build())
                    .layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), 2))
                    .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
                    .layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
                    .layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), rnnDataFormat))
                    .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .setInputType(InputType.recurrent(3))
                    .setInputType(InputType.recurrent(3, rnnDataFormat))
                    .build();

            MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);

@@ -62,13 +75,21 @@ public class TestTimeDistributed extends BaseDL4JTest {
            for( int mb : new int[]{1, 5}) {
                for(char inLabelOrder : new char[]{'c', 'f'}) {
                    INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder);

                    if (rnnDataFormat == RNNFormat.NWC){
                        in = in.permute(0, 2, 1);
                    }
                    INDArray out1 = net1.output(in);
                    INDArray out2 = net2.output(in);

                    assertEquals(out1, out2);

                    INDArray labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
                    INDArray labels ;
                    if (rnnDataFormat == RNNFormat.NCW) {
                        labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
                    }else{
                        labels = TestUtils.randomOneHotTimeSeries(mb, 5, 3).dup(inLabelOrder);
                    }

                    DataSet ds = new DataSet(in, labels);
                    net1.fit(ds);
@@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;

@@ -160,8 +161,8 @@ public class KerasConvolution1D extends KerasConvolution {
    public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException {
        if (inputType.length > 1)
            throw new InvalidKerasConfigurationException(
                    "Keras LSTM layer accepts only one input (received " + inputType.length + ")");
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
                    "Keras Conv1D layer accepts only one input (received " + inputType.length + ")");
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], RNNFormat.NCW,layerName);
    }
@@ -22,11 +22,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;

@@ -37,6 +35,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

@@ -266,7 +265,8 @@ public class KerasLSTM extends KerasLayer {
            throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one single input" +
                    "or three (input to LSTM and two states tensors, but " +
                    "received " + inputType.length + ".");
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
        RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f,layerName);
    }

    /**
@@ -21,7 +21,9 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer;

@@ -36,6 +38,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;

@@ -227,7 +230,8 @@ public class KerasSimpleRnn extends KerasLayer {
            throw new InvalidKerasConfigurationException(
                    "Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")");

        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
        RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName);
    }

    /**
@@ -218,7 +218,7 @@ public class KerasBidirectional extends KerasLayer {
        if (inputType.length > 1)
            throw new InvalidKerasConfigurationException(
                    "Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], ((Bidirectional)layer).getRNNDataFormat(), layerName);
    }

    /**
@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.api.layers;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;

@@ -98,4 +99,5 @@ public interface RecurrentLayer extends Layer {
     */
    Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr);

}
@@ -0,0 +1,29 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.conf;

/**
 * NCW = "channels first" - arrays of shape [minibatch, channels, width]<br>
 * NWC = "channels last" - arrays of shape [minibatch, width, channels]<br>
 * "width" corresponds to sequence length and "channels" corresponds to sequence item size.
 */
public enum RNNFormat {
    NCW,
    NWC
}
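To show how the new enum is meant to be used, here is a minimal, illustrative sketch (not part of this diff) of configuring a network that consumes NWC ([minibatch, timeSeriesLength, size]) data, assuming the dataFormat(...) builder methods and the InputType.recurrent(size, format) overload introduced elsewhere in this change:

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class NwcConfigSketch {
        public static void main(String[] args) {
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .list()
                    // every recurrent layer is told to expect [minibatch, time, channels]
                    .layer(new LSTM.Builder().nIn(5).nOut(8).dataFormat(RNNFormat.NWC).build())
                    .layer(new RnnOutputLayer.Builder().nIn(8).nOut(3)
                            .activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT)
                            .dataFormat(RNNFormat.NWC).build())
                    // the input type carries the format, so preprocessors are chosen to match
                    .setInputType(InputType.recurrent(5, RNNFormat.NWC))
                    .build();

            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
        }
    }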
@@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -111,9 +112,16 @@ public abstract class InputType implements Serializable {
     * @return InputTypeRecurrent
     */
    public static InputType recurrent(long size, long timeSeriesLength) {
        return new InputTypeRecurrent(size, timeSeriesLength);
        return new InputTypeRecurrent(size, timeSeriesLength, RNNFormat.NCW);
    }

    public static InputType recurrent(long size, RNNFormat format){
        return new InputTypeRecurrent(size, format);
    }

    public static InputType recurrent(long size, long timeSeriesLength, RNNFormat format){
        return new InputTypeRecurrent(size, timeSeriesLength, format);
    }
    /**
     * Input type for convolutional (CNN) data, that is 4d with shape [miniBatchSize, channels, height, width].
     * For CNN data that has been flattened, use {@link #convolutionalFlat(long, long, long)}

@@ -216,14 +224,23 @@ public abstract class InputType implements Serializable {
    public static class InputTypeRecurrent extends InputType {
        private long size;
        private long timeSeriesLength;

        private RNNFormat format = RNNFormat.NCW;
        public InputTypeRecurrent(long size) {
            this(size, -1);
        }
        public InputTypeRecurrent(long size, long timeSeriesLength){
            this(size, timeSeriesLength, RNNFormat.NCW);
        }

        public InputTypeRecurrent(@JsonProperty("size") long size, @JsonProperty("timeSeriesLength") long timeSeriesLength) {
        public InputTypeRecurrent(long size, RNNFormat format){
            this(size, -1, format);
        }
        public InputTypeRecurrent(@JsonProperty("size") long size,
                                  @JsonProperty("timeSeriesLength") long timeSeriesLength,
                                  @JsonProperty("format") RNNFormat format) {
            this.size = size;
            this.timeSeriesLength = timeSeriesLength;
            this.format = format;
        }

        @Override

@@ -234,9 +251,9 @@ public abstract class InputType implements Serializable {
        @Override
        public String toString() {
            if (timeSeriesLength > 0) {
                return "InputTypeRecurrent(" + size + ",timeSeriesLength=" + timeSeriesLength + ")";
                return "InputTypeRecurrent(" + size + ",timeSeriesLength=" + timeSeriesLength + ",format=" + format + ")";
            } else {
                return "InputTypeRecurrent(" + size + ")";
                return "InputTypeRecurrent(" + size + ",format=" + format + ")";
            }
        }

@@ -251,8 +268,23 @@ public abstract class InputType implements Serializable {

        @Override
        public long[] getShape(boolean includeBatchDim) {
            if(includeBatchDim) return new long[]{-1, size, timeSeriesLength};
            else return new long[]{size, timeSeriesLength};
            if (includeBatchDim){
                if (format == RNNFormat.NCW){
                    return new long[]{-1, size, timeSeriesLength};
                }
                else{
                    return new long[]{-1, timeSeriesLength, size};
                }
            }
            else{
                if (format == RNNFormat.NCW){
                    return new long[]{size, timeSeriesLength};
                }
                else{
                    return new long[]{timeSeriesLength, size};
                }
            }
        }
    }
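A quick sanity check of the new shape logic (illustrative only, not taken from the diff), assuming the getShape(boolean) contract shown above:

    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.deeplearning4j.nn.conf.inputs.InputType;

    import java.util.Arrays;

    public class RecurrentInputTypeShapeSketch {
        public static void main(String[] args) {
            InputType ncw = InputType.recurrent(16, 20, RNNFormat.NCW);
            InputType nwc = InputType.recurrent(16, 20, RNNFormat.NWC);

            // expected: NCW -> [-1, 16, 20], NWC -> [-1, 20, 16]
            System.out.println(Arrays.toString(ncw.getShape(true)));
            System.out.println(Arrays.toString(nwc.getShape(true)));
        }
    }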
@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.weights.IWeightInit;

@@ -35,10 +36,12 @@ import java.util.List;
public abstract class BaseRecurrentLayer extends FeedForwardLayer {

    protected IWeightInit weightInitFnRecurrent;
    protected RNNFormat rnnDataFormat = RNNFormat.NCW;

    protected BaseRecurrentLayer(Builder builder) {
        super(builder);
        this.weightInitFnRecurrent = builder.weightInitFnRecurrent;
        this.rnnDataFormat = builder.rnnDataFormat;
    }

    @Override

@@ -51,7 +54,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {

        InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;

        return InputType.recurrent(nOut, itr.getTimeSeriesLength());
        return InputType.recurrent(nOut, itr.getTimeSeriesLength(), itr.getFormat());
    }

    @Override

@@ -64,12 +67,13 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
        if (nIn <= 0 || override) {
            InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
            this.nIn = r.getSize();
            this.rnnDataFormat = r.getFormat();
        }
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat,getLayerName());
    }

    @NoArgsConstructor

@@ -77,6 +81,12 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
    @Setter
    public static abstract class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {

        /**
         * Set the format of data expected by the RNN. NCW = [miniBatchSize, size, timeSeriesLength],
         * NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
         */
        protected RNNFormat rnnDataFormat = RNNFormat.NCW;

        /**
         * Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no
         * constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,

@@ -163,5 +173,10 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
            this.setWeightInitFnRecurrent(new WeightInitDistribution(dist));
            return (T) this;
        }

        public T dataFormat(RNNFormat rnnDataFormat){
            this.rnnDataFormat = rnnDataFormat;
            return (T)this;
        }
    }
}
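The new Builder.dataFormat(...) method above is inherited by every concrete recurrent layer builder that extends BaseRecurrentLayer.Builder. A minimal usage sketch (illustrative, not part of the diff):

    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;

    public class DataFormatBuilderSketch {
        public static void main(String[] args) {
            // NWC: the layer will expect activations of shape [minibatch, timeSeriesLength, size]
            SimpleRnn layer = new SimpleRnn.Builder()
                    .nIn(4)
                    .nOut(6)
                    .dataFormat(RNNFormat.NWC)   // defaults to NCW when omitted
                    .build();

            System.out.println(layer);
        }
    }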
@@ -22,6 +22,7 @@ import lombok.NoArgsConstructor;
import lombok.ToString;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.Convolution1DUtils;

@@ -114,7 +115,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
                            + "\"): input is null");
        }

        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName());
    }

    public static class Builder extends ConvolutionLayer.BaseConvBuilder<Builder> {
@@ -87,7 +87,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
                return null;
            case RNN:
                //RNN -> FF
                return new RnnToFeedForwardPreProcessor();
                return new RnnToFeedForwardPreProcessor(((InputType.InputTypeRecurrent)inputType).getFormat());
            case CNN:
                //CNN -> FF
                InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
@@ -22,6 +22,7 @@ import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;

@@ -528,7 +529,7 @@ public class InputTypeUtil {
        }
    }

    public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, String layerName) {
    public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, RNNFormat rnnDataFormat, String layerName) {
        if (inputType == null) {
            throw new IllegalStateException(
                    "Invalid input for RNN layer (layer name = \"" + layerName + "\"): input type is null");

@@ -539,14 +540,14 @@ public class InputTypeUtil {
            case CNNFlat:
                //FF -> RNN or CNNFlat -> RNN
                //In either case, input data format is a row vector per example
                return new FeedForwardToRnnPreProcessor();
                return new FeedForwardToRnnPreProcessor(rnnDataFormat);
            case RNN:
                //RNN -> RNN: No preprocessor necessary
                return null;
            case CNN:
                //CNN -> RNN
                InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
                return new CnnToRnnPreProcessor(c.getHeight(), c.getWidth(), c.getChannels());
                return new CnnToRnnPreProcessor(c.getHeight(), c.getWidth(), c.getChannels(), rnnDataFormat);
            default:
                throw new RuntimeException("Unknown input type: " + inputType);
        }
@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;

@@ -86,7 +87,7 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName());
    }

    @Override
@@ -20,6 +20,7 @@ import lombok.*;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;

@@ -136,7 +137,7 @@ public class LocallyConnected1D extends SameDiffLayer {

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
    }

    @Override
@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;

@@ -92,7 +93,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
    }

    @Override
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;

@@ -53,12 +54,13 @@ import java.util.Map;
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class RnnLossLayer extends FeedForwardLayer {

    private RNNFormat rnnDataFormat = RNNFormat.NCW;
    protected ILossFunction lossFn;

    private RnnLossLayer(Builder builder) {
        super(builder);
        this.setLossFn(builder.lossFn);
        this.rnnDataFormat = builder.rnnDataFormat;
    }

    @Override

@@ -91,7 +93,7 @@ public class RnnLossLayer extends FeedForwardLayer {

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
    }

    @Override

@@ -111,8 +113,9 @@ public class RnnLossLayer extends FeedForwardLayer {

    public static class Builder extends BaseOutputLayer.Builder<Builder> {

        public Builder() {
        private RNNFormat rnnDataFormat = RNNFormat.NCW;

        public Builder() {
        }

        /**

@@ -153,6 +156,14 @@ public class RnnLossLayer extends FeedForwardLayer {
                            "This layer has no parameters, thus nIn will always equal nOut.");
        }

        /**
         * @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength],
         *                      NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
         */
        public Builder dataFormat(RNNFormat rnnDataFormat){
            this.rnnDataFormat = rnnDataFormat;
            return this;
        }
        @Override
        @SuppressWarnings("unchecked")
        public RnnLossLayer build() {
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;

@@ -51,9 +52,11 @@ import java.util.Map;
@EqualsAndHashCode(callSuper = true)
public class RnnOutputLayer extends BaseOutputLayer {

    private RNNFormat rnnDataFormat = RNNFormat.NCW;
    private RnnOutputLayer(Builder builder) {
        super(builder);
        initializeConstraints(builder);
        this.rnnDataFormat = builder.rnnDataFormat;
    }

    @Override

@@ -85,7 +88,7 @@ public class RnnOutputLayer extends BaseOutputLayer {
        }
        InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;

        return InputType.recurrent(nOut, itr.getTimeSeriesLength());
        return InputType.recurrent(nOut, itr.getTimeSeriesLength(), itr.getFormat());
    }

    @Override

@@ -97,18 +100,20 @@ public class RnnOutputLayer extends BaseOutputLayer {

        if (nIn <= 0 || override) {
            InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
            this.rnnDataFormat = r.getFormat();
            this.nIn = r.getSize();
        }
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat, getLayerName());
    }

    public static class Builder extends BaseOutputLayer.Builder<Builder> {

        private RNNFormat rnnDataFormat = RNNFormat.NCW;
        public Builder() {
            //Set default activation function to softmax (to match default loss function MCXENT)
            this.setActivationFn(new ActivationSoftmax());

@@ -137,5 +142,14 @@ public class RnnOutputLayer extends BaseOutputLayer {
        public RnnOutputLayer build() {
            return new RnnOutputLayer(this);
        }

        /**
         * @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength],
         *                      NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
         */
        public Builder dataFormat(RNNFormat rnnDataFormat){
            this.rnnDataFormat = rnnDataFormat;
            return this;
        }
    }
}
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers;

import lombok.*;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;

@@ -75,7 +76,7 @@ public class SelfAttentionLayer extends SameDiffLayer {

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName());
    }

    @Override
@@ -22,6 +22,7 @@ import lombok.NoArgsConstructor;
import lombok.ToString;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.Convolution1DUtils;

@@ -105,7 +106,7 @@ public class Subsampling1DLayer extends SubsamplingLayer {
                            + "\"): input is null");
        }

        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
    }

    @Override
@@ -20,6 +20,7 @@ import lombok.*;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;

@@ -104,7 +105,7 @@ public class ZeroPadding1DLayer extends NoParamLayer {
                            + "\"): input is null");
        }

        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
    }

    @Override
@@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;

@@ -30,6 +31,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.params.BidirectionalParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;

@@ -124,6 +126,10 @@ public class Bidirectional extends Layer {
        }
    }

    public RNNFormat getRNNDataFormat(){
        return TimeSeriesUtils.getFormatFromRnnLayer(fwd);
    }

    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
                    Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,

@@ -170,7 +176,7 @@ public class Bidirectional extends Layer {
        } else {
            InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) outOrig;
            if (mode == Mode.CONCAT) {
                return InputType.recurrent(2 * r.getSize());
                return InputType.recurrent(2 * r.getSize(), getRNNDataFormat());
            } else {
                return r;
            }
@@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
import lombok.NonNull;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;

@@ -29,17 +30,19 @@ import java.util.Collection;
@EqualsAndHashCode(callSuper = true)
public class TimeDistributed extends BaseWrapperLayer {

    private final int timeAxis;
    private RNNFormat rnnDataFormat = RNNFormat.NCW;

    /**
     * @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayer
     * @param timeAxis Time axis, should be 2 for DL4J RNN activations (shape [minibatch, size, sequenceLength])
     */
    public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("timeAxis") int timeAxis) {
    public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) {
        super(underlying);
        this.timeAxis = timeAxis;
        this.rnnDataFormat = rnnDataFormat;
    }

    public TimeDistributed(Layer underlying){
        super(underlying);
    }

    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,

@@ -47,7 +50,7 @@ public class TimeDistributed extends BaseWrapperLayer {
        NeuralNetConfiguration conf2 = conf.clone();
        conf2.setLayer(((TimeDistributed) conf2.getLayer()).getUnderlying());
        return new TimeDistributedLayer(underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView,
                initializeParams, networkDataType), timeAxis);
                initializeParams, networkDataType), rnnDataFormat);
    }

    @Override

@@ -59,7 +62,7 @@ public class TimeDistributed extends BaseWrapperLayer {
        InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
        InputType ff = InputType.feedForward(rnn.getSize());
        InputType ffOut = underlying.getOutputType(layerIndex, ff);
        return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength());
        return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength(), rnnDataFormat);
    }

    @Override

@@ -70,6 +73,7 @@ public class TimeDistributed extends BaseWrapperLayer {

        InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
        InputType ff = InputType.feedForward(rnn.getSize());
        this.rnnDataFormat = rnn.getFormat();
        underlying.setNIn(ff, override);
    }
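For orientation, the constructor change above replaces the old integer time-axis argument with the RNNFormat enum. A small illustrative sketch (not from the diff) of both constructors:

    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;

    public class TimeDistributedSketch {
        public static void main(String[] args) {
            // Explicit format: the wrapped DenseLayer is applied to each step of NWC activations
            TimeDistributed tdNwc = new TimeDistributed(
                    new DenseLayer.Builder().nIn(8).nOut(8).build(), RNNFormat.NWC);

            // Format omitted: it defaults to NCW and is later overwritten from the InputType in setNIn(...)
            TimeDistributed tdDefault = new TimeDistributed(
                    new DenseLayer.Builder().nIn(8).nOut(8).build());

            System.out.println(tdNwc + " / " + tdDefault);
        }
    }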
@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.preprocessor;
import lombok.*;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;

@@ -38,7 +39,7 @@ import java.util.Arrays;
 * Functionally equivalent to combining CnnToFeedForwardPreProcessor + FeedForwardToRnnPreProcessor<br>
 * Specifically, this does two things:<br>
 * (a) Reshape 4d activations out of CNN layer, with shape [timeSeriesLength*miniBatchSize, numChannels, inputHeight, inputWidth])
 * into 3d (time series) activations (with shape [numExamples, inputHeight*inputWidth*numChannels, timeSeriesLength])
 * into 3d (time series) activations (with shape [miniBatchSize, inputHeight*inputWidth*numChannels, timeSeriesLength])
 * for use in RNN layers<br>
 * (b) Reshapes 3d epsilons (weights.*deltas) out of RNN layer (with shape
 * [miniBatchSize,inputHeight*inputWidth*numChannels,timeSeriesLength]) into 4d epsilons with shape

@@ -52,6 +53,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {
    private long inputHeight;
    private long inputWidth;
    private long numChannels;
    private RNNFormat rnnDataFormat = RNNFormat.NCW;

    @Getter(AccessLevel.NONE)
    @Setter(AccessLevel.NONE)

@@ -59,11 +61,20 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {

    @JsonCreator
    public CnnToRnnPreProcessor(@JsonProperty("inputHeight") long inputHeight,
                    @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) {
                    @JsonProperty("inputWidth") long inputWidth,
                    @JsonProperty("numChannels") long numChannels,
                    @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = numChannels;
        this.product = inputHeight * inputWidth * numChannels;
        this.rnnDataFormat = rnnDataFormat;
    }

    public CnnToRnnPreProcessor(long inputHeight,
                    long inputWidth,
                    long numChannels){
        this(inputHeight, inputWidth, numChannels, RNNFormat.NCW);
    }

    @Override

@@ -90,14 +101,19 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {
        //Second: reshape 2d to 3d, as per FeedForwardToRnnPreProcessor
        INDArray reshaped = workspaceMgr.dup(ArrayType.ACTIVATIONS, twod, 'f');
        reshaped = reshaped.reshape('f', miniBatchSize, shape[0] / miniBatchSize, product);
        return reshaped.permute(0, 2, 1);
        if (rnnDataFormat == RNNFormat.NCW) {
            return reshaped.permute(0, 2, 1);
        }
        return reshaped;
    }

    @Override
    public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        if (output.ordering() == 'c' || !Shape.hasDefaultStridesForShape(output))
            output = output.dup('f');

        if (rnnDataFormat == RNNFormat.NWC){
            output = output.permute(0, 2, 1);
        }
        val shape = output.shape();
        INDArray output2d;
        if (shape[0] == 1) {

@@ -122,7 +138,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {

    @Override
    public CnnToRnnPreProcessor clone() {
        return new CnnToRnnPreProcessor(inputHeight, inputWidth, numChannels);
        return new CnnToRnnPreProcessor(inputHeight, inputWidth, numChannels, rnnDataFormat);
    }

    @Override

@@ -133,7 +149,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor {

        InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
        val outSize = c.getChannels() * c.getHeight() * c.getWidth();
        return InputType.recurrent(outSize);
        return InputType.recurrent(outSize, rnnDataFormat);
    }

    @Override
@@ -21,6 +21,7 @@ import lombok.NoArgsConstructor;
import lombok.val;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -28,7 +29,7 @@ import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;

import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Arrays;

/**

@@ -48,7 +49,11 @@ import java.util.Arrays;
@Data
@NoArgsConstructor
public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
    private RNNFormat rnnDataFormat = RNNFormat.NCW;

    public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
        this.rnnDataFormat = rnnDataFormat;
    }
    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        //Need to reshape FF activations (2d) activations to 3d (for input into RNN layer)

@@ -60,7 +65,10 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {

        val shape = input.shape();
        INDArray reshaped = input.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]);
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reshaped.permute(0, 2, 1));
        if (rnnDataFormat == RNNFormat.NCW){
            reshaped = reshaped.permute(0, 2, 1);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reshaped);
    }

    @Override

@@ -71,6 +79,9 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
                        "Invalid input: expect NDArray with rank 3 (i.e., epsilons from RNN layer)");
        if (output.ordering() != 'f' || !Shape.hasDefaultStridesForShape(output))
            output = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, output, 'f');
        if (rnnDataFormat == RNNFormat.NWC){
            output = output.permute(0, 2, 1);
        }
        val shape = output.shape();

        INDArray ret;

@@ -87,12 +98,7 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {

    @Override
    public FeedForwardToRnnPreProcessor clone() {
        try {
            FeedForwardToRnnPreProcessor clone = (FeedForwardToRnnPreProcessor) super.clone();
            return clone;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
        return new FeedForwardToRnnPreProcessor(rnnDataFormat);
    }

    @Override

@@ -104,10 +110,10 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {

        if (inputType.getType() == InputType.Type.FF) {
            InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward) inputType;
            return InputType.recurrent(ff.getSize());
            return InputType.recurrent(ff.getSize(), rnnDataFormat);
        } else {
            InputType.InputTypeConvolutionalFlat cf = (InputType.InputTypeConvolutionalFlat) inputType;
            return InputType.recurrent(cf.getFlattenedSize());
            return InputType.recurrent(cf.getFlattenedSize(), rnnDataFormat);
        }
    }
@ -19,8 +19,10 @@ package org.deeplearning4j.nn.conf.preprocessor;
import lombok.*;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.primitives.Pair;

@ -52,19 +54,27 @@ public class RnnToCnnPreProcessor implements InputPreProcessor {
private int inputHeight;
private int inputWidth;
private int numChannels;

private RNNFormat rnnDataFormat = RNNFormat.NCW;
@Getter(AccessLevel.NONE)
@Setter(AccessLevel.NONE)
private int product;

public RnnToCnnPreProcessor(@JsonProperty("inputHeight") int inputHeight,
@JsonProperty("inputWidth") int inputWidth, @JsonProperty("numChannels") int numChannels) {
@JsonProperty("inputWidth") int inputWidth,
@JsonProperty("numChannels") int numChannels,
@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = numChannels;
this.product = inputHeight * inputWidth * numChannels;
this.rnnDataFormat = rnnDataFormat;
}

public RnnToCnnPreProcessor(int inputHeight,
int inputWidth,
int numChannels){
this(inputHeight, inputWidth, numChannels, RNNFormat.NCW);
}

@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {

@ -72,6 +82,9 @@ public class RnnToCnnPreProcessor implements InputPreProcessor {
input = input.dup('f');
//Input: 3d activations (RNN)
//Output: 4d activations (CNN)
if (rnnDataFormat == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
val shape = input.shape();
INDArray in2d;
if (shape[0] == 1) {

@ -98,14 +111,17 @@ public class RnnToCnnPreProcessor implements InputPreProcessor {
val shape = output.shape();
//First: reshape 4d to 2d
INDArray twod = output.reshape('c', output.size(0), ArrayUtil.prod(output.shape()) / output.size(0));
//Second: reshape 2d to 4d
//Second: reshape 2d to 3d
INDArray reshaped = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, twod, 'f').reshape('f', miniBatchSize, shape[0] / miniBatchSize, product);
return reshaped.permute(0, 2, 1);
if (rnnDataFormat == RNNFormat.NCW) {
reshaped = reshaped.permute(0, 2, 1);
}
return reshaped;
}

@Override
public RnnToCnnPreProcessor clone() {
return new RnnToCnnPreProcessor(inputHeight, inputWidth, numChannels);
return new RnnToCnnPreProcessor(inputHeight, inputWidth, numChannels, rnnDataFormat);
}

@Override
@ -16,11 +16,14 @@

package org.deeplearning4j.nn.conf.preprocessor;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;

@ -28,6 +31,7 @@ import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.Arrays;

@ -47,8 +51,14 @@ import java.util.Arrays;
*/
@Data
@Slf4j
@NoArgsConstructor
public class RnnToFeedForwardPreProcessor implements InputPreProcessor {

private RNNFormat rnnDataFormat = RNNFormat.NCW;

public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
//Need to reshape RNN activations (3d) activations to 2d (for input into feed forward layer)

@ -59,10 +69,13 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
if (input.ordering() != 'f' || !Shape.hasDefaultStridesForShape(input))
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'f');

if (rnnDataFormat == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
val shape = input.shape();
INDArray ret;
if (shape[0] == 1) {
ret = input.tensorAlongDimension(0, 1, 2).permutei(1, 0); //Edge case: miniBatchSize==1
ret = input.tensorAlongDimension(0, 1, 2).permute(1, 0); //Edge case: miniBatchSize==1
} else if (shape[2] == 1) {
ret = input.tensorAlongDimension(0, 1, 0); //Edge case: timeSeriesLength=1
} else {

@ -85,17 +98,15 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
val shape = output.shape();
INDArray reshaped = output.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]);
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, reshaped.permute(0, 2, 1));
if (rnnDataFormat == RNNFormat.NCW){
reshaped = reshaped.permute(0, 2, 1);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, reshaped);
}

@Override
public RnnToFeedForwardPreProcessor clone() {
try {
RnnToFeedForwardPreProcessor clone = (RnnToFeedForwardPreProcessor) super.clone();
return clone;
} catch (CloneNotSupportedException e) {
throw new RuntimeException(e);
}
return new RnnToFeedForwardPreProcessor(rnnDataFormat);
}

@Override
@ -18,7 +18,10 @@ package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

@ -26,7 +29,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public abstract class BaseRecurrentLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseLayer>
public abstract class BaseRecurrentLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer>
extends BaseLayer<LayerConfT> implements RecurrentLayer {

/**

@ -85,4 +88,19 @@ public abstract class BaseRecurrentLayer<LayerConfT extends org.deeplearning4j.n
tBpttStateMap.putAll(state);
}

public RNNFormat getDataFormat(){
return layerConf().getRnnDataFormat();
}

protected INDArray permuteIfNWC(INDArray arr){
if (arr == null){
return null;
}
if (getDataFormat() == RNNFormat.NWC){
return arr.permute(0, 2, 1);
}
return arr;
}

}
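The permuteIfNWC helper is how the concrete RNN layers keep their internal math in NCW: when the configured format is NWC it swaps the time and feature axes, otherwise it is a no-op. A minimal sketch of the equivalent permutation on a raw array (shapes are illustrative):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class NwcPermuteSketch {
        public static void main(String[] args) {
            // NWC activations: [minibatch, seqLength, size]
            INDArray nwc = Nd4j.rand(new int[]{8, 20, 16});
            // Same data viewed as NCW: [minibatch, size, seqLength]
            INDArray ncw = nwc.permute(0, 2, 1);
            System.out.println(java.util.Arrays.toString(ncw.shape())); // [8, 16, 20]
        }
    }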
@ -25,6 +25,7 @@ import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;

@ -78,6 +79,9 @@ public class BidirectionalLayer implements RecurrentLayer {
this.paramsView = paramsView;
}

private RNNFormat getRNNDataFormat(){
return layerConf.getRNNDataFormat();
}
@Override
public INDArray rnnTimeStep(INDArray input, LayerWorkspaceMgr workspaceMgr) {
throw new UnsupportedOperationException("Cannot RnnTimeStep bidirectional layers");

@ -140,7 +144,10 @@ public class BidirectionalLayer implements RecurrentLayer {
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
INDArray eFwd;
INDArray eBwd;

boolean permute = getRNNDataFormat() == RNNFormat.NWC && epsilon.rank() == 3;
if (permute){
epsilon = epsilon.permute(0, 2, 1);
}
val n = epsilon.size(1)/2;
switch (layerConf.getMode()){
case ADD:

@ -165,6 +172,10 @@ public class BidirectionalLayer implements RecurrentLayer {
eBwd = TimeSeriesUtils.reverseTimeSeries(eBwd, workspaceMgr, ArrayType.BP_WORKING_MEM);

if (permute){
eFwd = eFwd.permute(0, 2, 1);
eBwd = eBwd.permute(0, 2, 1);
}
Pair<Gradient,INDArray> g1 = fwd.backpropGradient(eFwd, workspaceMgr);
Pair<Gradient,INDArray> g2 = bwd.backpropGradient(eBwd, workspaceMgr);

@ -176,7 +187,9 @@ public class BidirectionalLayer implements RecurrentLayer {
g.gradientForVariable().put(BidirectionalParamInitializer.BACKWARD_PREFIX + e.getKey(), e.getValue());
}

INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2.getRight(), workspaceMgr, ArrayType.BP_WORKING_MEM);
INDArray g2Right = permute ? g2.getRight().permute(0, 2, 1): g2.getRight();
INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2Right, workspaceMgr, ArrayType.BP_WORKING_MEM);
g2Reversed = permute? g2Reversed.permute(0, 2, 1): g2Reversed;
INDArray epsOut = g1.getRight().addi(g2Reversed);

return new Pair<>(g, epsOut);

@ -186,25 +199,38 @@ public class BidirectionalLayer implements RecurrentLayer {
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
INDArray out1 = fwd.activate(training, workspaceMgr);
INDArray out2 = bwd.activate(training, workspaceMgr);
boolean permute = getRNNDataFormat() == RNNFormat.NWC && out1.rank() == 3;
if(permute){
out1 = out1.permute(0, 2, 1);
out2 = out2.permute(0, 2, 1);
}
//Reverse the output time series. Note: when using LastTimeStepLayer, output can be rank 2
out2 = out2.rank() == 2 ? out2 : TimeSeriesUtils.reverseTimeSeries(out2, workspaceMgr, ArrayType.FF_WORKING_MEM);

INDArray ret;
switch (layerConf.getMode()){
case ADD:
return out1.addi(out2);
ret = out1.addi(out2);
break;
case MUL:
//TODO may be more efficient ways than this...
this.outFwd = out1.detach();
this.outBwd = out2.detach();
return workspaceMgr.dup(ArrayType.ACTIVATIONS, out1).muli(out2);
ret = workspaceMgr.dup(ArrayType.ACTIVATIONS, out1).muli(out2);
break;
case AVERAGE:
return out1.addi(out2).muli(0.5);
ret = out1.addi(out2).muli(0.5);
break;
case CONCAT:
INDArray ret = Nd4j.concat(1, out1, out2);
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
ret = Nd4j.concat(1, out1, out2);
ret = workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
break;
default:
throw new RuntimeException("Unknown mode: " + layerConf.getMode());
}
if (permute){
ret = ret.permute(0, 2, 1);
}
return ret;
}

@Override

@ -465,7 +491,9 @@ public class BidirectionalLayer implements RecurrentLayer {
public void setInput(INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) {
this.input = input;
fwd.setInput(input, layerWorkspaceMgr);

if (getRNNDataFormat() == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
INDArray reversed;
if(!input.isAttached()){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {

@ -478,6 +506,9 @@ public class BidirectionalLayer implements RecurrentLayer {
reversed = TimeSeriesUtils.reverseTimeSeries(input);
}
}
if (getRNNDataFormat() == RNNFormat.NWC){
reversed = reversed.permute(0, 2, 1);
}
bwd.setInput(reversed, layerWorkspaceMgr);
}
@ -88,12 +88,12 @@ public class GravesBidirectionalLSTM
}

final FwdPassReturn fwdPass = activateHelperDirectional(true, null, null, true, true, workspaceMgr);

fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput);
final Pair<Gradient, INDArray> forwardsGradient = LSTMHelpers.backpropGradientHelper(this,
this.conf,
this.layerConf().getGateActivationFn(), this.input,
this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), epsilon,
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), permuteIfNWC(epsilon),
truncatedBPTT, tbpttBackwardLength, fwdPass, true,
GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS,
GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS,

@ -106,16 +106,17 @@ public class GravesBidirectionalLSTM

final Pair<Gradient, INDArray> backwardsGradient = LSTMHelpers.backpropGradientHelper(this,
this.conf,
this.layerConf().getGateActivationFn(), this.input,
this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), epsilon,
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), permuteIfNWC(epsilon),
truncatedBPTT, tbpttBackwardLength, backPass, false,
GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS,
GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS,
GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, gradientViews, maskArray, true,
null, workspaceMgr, layerConf().isHelperAllowFallback());

forwardsGradient.setSecond(permuteIfNWC(forwardsGradient.getSecond()));
backwardsGradient.setSecond(permuteIfNWC(backwardsGradient.getSecond()));
//merge the gradient, which is key value pair of String,INDArray
//the keys for forwards and backwards should be different

@ -171,7 +172,7 @@ public class GravesBidirectionalLSTM
} else {

forwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
this.input, getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS),
permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), training, null, null,
forBackprop || (cacheMode != CacheMode.NONE && training), true,

@ -179,7 +180,7 @@ public class GravesBidirectionalLSTM
forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback());

backwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
this.input,
permuteIfNWC(this.input),
getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS),
getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS), training, null, null,

@ -187,6 +188,8 @@ public class GravesBidirectionalLSTM
GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, maskArray, true, null,
forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback());

forwardsEval.fwdPassOutput = permuteIfNWC(forwardsEval.fwdPassOutput);
backwardsEval.fwdPassOutput = permuteIfNWC(backwardsEval.fwdPassOutput);
cachedPassForward = forwardsEval;
cachedPassBackward = backwardsEval;
}

@ -228,10 +231,12 @@ public class GravesBidirectionalLSTM
biasKey = GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS;
}

return LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), this.input,
FwdPassReturn ret = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
getParam(recurrentKey), getParam(inputKey), getParam(biasKey), training,
prevOutputActivations, prevMemCellState, forBackprop, forwards, inputKey, maskArray, true,
null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback());
ret.fwdPassOutput = permuteIfNWC(ret.fwdPassOutput);
return ret;
}
}
@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.nd4j.base.Preconditions;

@ -89,17 +90,17 @@ public class GravesLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.la
} else {
fwdPass = activateHelper(true, null, null, true, workspaceMgr);
}

fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput);

Pair<Gradient, INDArray> p = LSTMHelpers.backpropGradientHelper(this,
this.conf, this.layerConf().getGateActivationFn(), this.input,
recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true,
this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true,
GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY,
GravesLSTMParamInitializer.BIAS_KEY, gradientViews, maskArray, true, null,
workspaceMgr, layerConf().isHelperAllowFallback());

weightNoiseParams.clear();
p.setSecond(backpropDropOutIfPresent(p.getSecond()));
p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond())));
return p;
}

@ -117,8 +118,8 @@ public class GravesLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.la
private FwdPassReturn activateHelper(final boolean training, final INDArray prevOutputActivations,
final INDArray prevMemCellState, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false);
Preconditions.checkState(input.rank() == 3,
"3D input expected to RNN layer expected, got " + input.rank());
Preconditions.checkState(this.input.rank() == 3,
"3D input expected to RNN layer expected, got " + this.input.rank());
applyDropOutIfNecessary(training, workspaceMgr);

// if (cacheMode == null)

@ -136,18 +137,17 @@ public class GravesLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.la
final INDArray recurrentWeights = getParamWithNoise(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights = getParamWithNoise(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, training, workspaceMgr); //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray biases = getParamWithNoise(GravesLSTMParamInitializer.BIAS_KEY, training, workspaceMgr); //by row: IFOG //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T

INDArray input = permuteIfNWC(this.input);
FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
this.input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
prevMemCellState, forBackprop || (cacheMode != CacheMode.NONE && training), true,
GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, maskArray, true, null,
cacheMode, workspaceMgr, layerConf().isHelperAllowFallback());

fwd.fwdPassOutput = permuteIfNWC(fwd.fwdPassOutput);
if (training && cacheMode != CacheMode.NONE) {
cachedFwdPass = fwd;
}

return fwd;
}
@ -123,17 +123,16 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
} else {
fwdPass = activateHelper(true, null, null, true, workspaceMgr);
}

fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput);
Pair<Gradient,INDArray> p = LSTMHelpers.backpropGradientHelper(this,
this.conf, this.layerConf().getGateActivationFn(), this.input,
recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true,
this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input),
recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true,
LSTMParamInitializer.INPUT_WEIGHT_KEY, LSTMParamInitializer.RECURRENT_WEIGHT_KEY,
LSTMParamInitializer.BIAS_KEY, gradientViews, null, false, helper, workspaceMgr,
layerConf().isHelperAllowFallback());

weightNoiseParams.clear();
p.setSecond(backpropDropOutIfPresent(p.getSecond()));
p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond())));
return p;
}

@ -167,17 +166,18 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
final INDArray recurrentWeights = getParamWithNoise(LSTMParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights = getParamWithNoise(LSTMParamInitializer.INPUT_WEIGHT_KEY, training, workspaceMgr); //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray biases = getParamWithNoise(LSTMParamInitializer.BIAS_KEY, training, workspaceMgr); //by row: IFOG //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T

INDArray input = permuteIfNWC(this.input);
FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
this.input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
prevMemCellState, (training && cacheMode != CacheMode.NONE) || forBackprop, true,
LSTMParamInitializer.INPUT_WEIGHT_KEY, maskArray, false, helper,
forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback());

fwd.fwdPassOutput = permuteIfNWC(fwd.fwdPassOutput);

if (training && cacheMode != CacheMode.NONE) {
cachedFwdPass = fwd;
}

return fwd;
}
@ -465,7 +465,6 @@ public class LSTMHelpers {
val miniBatchSize = epsilon.size(0);
boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
val timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));

INDArray wFFTranspose = null;
INDArray wOOTranspose = null;
INDArray wGGTranspose = null;

@ -573,14 +572,14 @@ public class LSTMHelpers {
nablaCellState = Nd4j.create(inputWeights.dataType(), new long[]{miniBatchSize, hiddenLayerSize}, 'f');
}

INDArray prevMemCellState = (iTimeIndex == 0 ? fwdPass.prevMemCell : fwdPass.memCellState[(int) (time - inext)]);
INDArray prevMemCellState = (iTimeIndex == 0 ? fwdPass.prevMemCell : fwdPass.memCellState[(time - inext)]);
INDArray prevHiddenUnitActivation =
(iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[(int) (time - inext)]);
INDArray currMemCellState = fwdPass.memCellState[(int) time];
(iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[(time - inext)]);
INDArray currMemCellState = fwdPass.memCellState[time];

//LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)
INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension((int) time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.

INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f'); //Shape: [m,n^L]
if (iTimeIndex != timeSeriesLength - 1) {
//if t == timeSeriesLength-1 then deltaiNext etc are zeros

@ -666,7 +665,7 @@ public class LSTMHelpers {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid using errors from a masked time step
// to calculate the parameter gradients. Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step)
timeStepMaskColumn = maskArray.getColumn(time, true);
deltaifogNext.muliColumnVector(timeStepMaskColumn);
deltaifogNext.muli(timeStepMaskColumn);
//Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients
}

@ -737,7 +736,7 @@ public class LSTMHelpers {
if (maskArray != null) {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything
// but 0s to the layer below at this time step (for the given example)
epsilonNextSlice.muliColumnVector(timeStepMaskColumn);
epsilonNextSlice.muli(timeStepMaskColumn);
}
}
}
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.recurrent;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.util.TimeSeriesUtils;

@ -59,18 +60,41 @@ public class LastTimeStepLayer extends BaseWrapperLayer {

@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
INDArray newEps = Nd4j.create(epsilon.dataType(), origOutputShape, 'f');
long[] newEpsShape = origOutputShape;
boolean nwc = (underlying instanceof BaseRecurrentLayer &&
((BaseRecurrentLayer) underlying).getDataFormat() == RNNFormat.NWC)||
(underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof
BaseRecurrentLayer && ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat()
== RNNFormat.NWC);
INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f');
if(lastTimeStepIdxs == null){
//no mask case
newEps.put(new INDArrayIndex[]{all(), all(), point(origOutputShape[2]-1)}, epsilon);
} else {
INDArrayIndex[] arr = new INDArrayIndex[]{null, all(), null};
//TODO probably possible to optimize this with reshape + scatter ops...
for( int i=0; i<lastTimeStepIdxs.length; i++ ){
arr[0] = point(i);
arr[2] = point(lastTimeStepIdxs[i]);
newEps.put(arr, epsilon.getRow(i));
if (nwc){
newEps.put(new INDArrayIndex[]{all(), point(origOutputShape[1]-1), all()}, epsilon);
}
else{
newEps.put(new INDArrayIndex[]{all(), all(), point(origOutputShape[2]-1)}, epsilon);
}
} else {
if (nwc){
INDArrayIndex[] arr = new INDArrayIndex[]{null, null, all()};
//TODO probably possible to optimize this with reshape + scatter ops...
for( int i=0; i<lastTimeStepIdxs.length; i++ ){
arr[0] = point(i);
arr[1] = point(lastTimeStepIdxs[i]);
newEps.put(arr, epsilon.getRow(i));
}
}
else{
INDArrayIndex[] arr = new INDArrayIndex[]{null, all(), null};
//TODO probably possible to optimize this with reshape + scatter ops...
for( int i=0; i<lastTimeStepIdxs.length; i++ ){
arr[0] = point(i);
arr[2] = point(lastTimeStepIdxs[i]);
newEps.put(arr, epsilon.getRow(i));
}
}

}
return underlying.backpropGradient(newEps, workspaceMgr);
}

@ -103,10 +127,18 @@ public class LastTimeStepLayer extends BaseWrapperLayer {
"rank " + in.rank() + " with shape " + Arrays.toString(in.shape()));
}
origOutputShape = in.shape();
boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC;
// underlying instanceof BaseRecurrentLayer && ((BaseRecurrentLayer)underlying).getDataFormat() == RNNFormat.NWC)||
// underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof BaseRecurrentLayer &&
// ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat() == RNNFormat.NWC;
if (nwc){
in = in.permute(0, 2, 1);
}

INDArray mask = underlying.getMaskArray();
Pair<INDArray,int[]> p = TimeSeriesUtils.pullLastTimeSteps(in, mask, workspaceMgr, arrayType);
lastTimeStepIdxs = p.getSecond();

return p.getFirst();
}
}
@ -30,6 +30,9 @@ import org.nd4j.linalg.primitives.Pair;
import lombok.NonNull;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;

import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
import static org.deeplearning4j.nn.conf.RNNFormat.NWC;

/**
* Masks timesteps with activation equal to the specified masking value, defaulting to 0.0.
* Assumes that the input shape is [batch_size, input_size, timesteps].

@ -76,7 +79,11 @@ public class MaskZeroLayer extends BaseWrapperLayer {
throw new IllegalArgumentException("Expected input of shape [batch_size, timestep_input_size, timestep], " +
"got shape "+Arrays.toString(input.shape()) + " instead");
}
INDArray mask = input.eq(maskingValue).castTo(input.dataType()).sum(1).neq(input.shape()[1]);
if ((underlying instanceof BaseRecurrentLayer &&
((BaseRecurrentLayer)underlying).getDataFormat() == NWC)){
input = input.permute(0, 2, 1);
}
INDArray mask = input.eq(maskingValue).castTo(input.dataType()).sum(1).neq(input.shape()[1]).castTo(input.dataType());
underlying.setMaskArray(mask.detach());
}
@ -22,6 +22,7 @@ import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;

@ -60,6 +61,8 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(true);
INDArray input = this.input;
INDArray labels = this.labels;
if (input.rank() != 3)
throw new UnsupportedOperationException(
"Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " +

@ -67,6 +70,10 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
if (labels == null)
throw new IllegalStateException("Labels are not set (null)");

if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
input = input.permute(0, 2, 1);
labels = labels.permute(0, 2, 1);
}
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);

@ -90,7 +97,9 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped);

INDArray delta3d = TimeSeriesUtils.reshape2dTo3d(delta2d, input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);

if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
delta3d = delta3d.permute(0, 2, 1);
}
// grab the empty gradient
Gradient gradient = new DefaultGradient();
return new Pair<>(gradient, delta3d);

@ -159,13 +168,21 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false);
INDArray input = this.input;
if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
if (input.rank() != 3)
throw new UnsupportedOperationException(
"Input must be rank 3. Got input with rank " + input.rank() + " " + layerId());

INDArray as2d = TimeSeriesUtils.reshape3dTo2d(input);
INDArray out2d = layerConf().getActivationFn().getActivation(workspaceMgr.dup(ArrayType.ACTIVATIONS, as2d, as2d.ordering()), training);
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, TimeSeriesUtils.reshape2dTo3d(out2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS));
INDArray ret = workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, TimeSeriesUtils.reshape2dTo3d(out2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS));
if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
ret = ret.permute(0, 2, 1);
}
return ret;
}

@Override

@ -196,6 +213,12 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn

@Override
public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
INDArray input = this.input;
INDArray labels = this.labels;
if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
input = input.permute(0, 2, 1);
labels = labels.permute(0, 2, 1);
}
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray maskReshaped;

@ -228,10 +251,14 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
@Override
public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
//For RNN: need to sum up the score over each time step before returning.

INDArray input = this.input;
INDArray labels = this.labels;
if (input == null || labels == null)
throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());

if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
input = input.permute(0, 2, 1);
labels = labels.permute(0, 2, 1);
}
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.FF_WORKING_MEM);
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.recurrent;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;

@ -57,11 +58,15 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
}
int td = (layerConf().getRnnDataFormat()==RNNFormat.NCW)? 2: 1;
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);

INDArray inputTemp = input;
if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
this.input = input.permute(0, 2, 1);
}
this.input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);

applyDropOutIfNecessary(true, workspaceMgr); //Edge case: we skip OutputLayer forward pass during training as this isn't required to calculate gradients

@ -71,7 +76,9 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
INDArray epsilon2d = gradAndEpsilonNext.getSecond();

INDArray epsilon3d = TimeSeriesUtils.reshape2dTo3d(epsilon2d, input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);

if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
epsilon3d = epsilon3d.permute(0, 2, 1);
}
weightNoiseParams.clear();

//epsilon3d = backpropDropOutIfPresent(epsilon3d);

@ -104,6 +111,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
if (input.rank() == 3) {
//Case when called from RnnOutputLayer
INDArray inputTemp = input;
input = (layerConf().getRnnDataFormat()==RNNFormat.NWC)? input.permute(0, 2, 1):input;
input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray out = super.preOutput(training, workspaceMgr);
this.input = inputTemp;

@ -117,13 +125,17 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l

@Override
protected INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
if (labels.rank() == 3)
INDArray labels = this.labels;
if (labels.rank() == 3){
labels = (layerConf().getRnnDataFormat()==RNNFormat.NWC)?labels.permute(0, 2, 1):labels;
return TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, arrayType);
}
return labels;
}

@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
INDArray input = this.input;
if (input.rank() != 3)
throw new UnsupportedOperationException(
"Input must be rank 3. Got input with rank " + input.rank() + " " + layerId());

@ -131,6 +143,9 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr);

applyDropOutIfNecessary(training, workspaceMgr);
if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input.castTo(W.dataType()), workspaceMgr, ArrayType.FF_WORKING_MEM);

INDArray act2d = layerConf().getActivationFn().getActivation(input2d.mmul(W).addiRowVector(b), training);

@ -144,7 +159,11 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
}
}

return TimeSeriesUtils.reshape2dTo3d(act2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS);
INDArray ret = TimeSeriesUtils.reshape2dTo3d(act2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS);
if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
ret = ret.permute(0, 2, 1);
}
return ret;
}

@Override
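End to end, the point of these layer changes is that a network can now consume and emit [minibatch, seqLength, size] arrays directly. A minimal configuration sketch, assuming default builder settings (layer sizes and sequence lengths are illustrative):

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class NwcNetworkSketch {
        public static void main(String[] args) {
            int nIn = 5, nOut = 3;
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .list()
                    .layer(new LSTM.Builder().nOut(10).activation(Activation.TANH).build())
                    .layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    // NWC is propagated to the layers via the input type
                    .setInputType(InputType.recurrent(nIn, RNNFormat.NWC))
                    .build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            INDArray features = Nd4j.rand(new int[]{2, 7, nIn});   // [minibatch, seqLength, nIn]
            INDArray out = net.output(features);                    // [2, 7, nOut], also NWC
            System.out.println(java.util.Arrays.toString(out.shape()));
        }
    }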
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.recurrent;

import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;

@ -50,6 +51,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn> {
public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";

public SimpleRnn(NeuralNetConfiguration conf, DataType dataType) {
super(conf, dataType);
}

@ -92,6 +94,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
val nOut = layerConf().getNOut();

INDArray input = this.input.castTo(dataType); //No-op if correct type
input = permuteIfNWC(input);

//First: Do forward pass to get gate activations and Zs
Quad<INDArray,INDArray, INDArray, INDArray> p = activateHelper(null, true, true, workspaceMgr);

@ -125,8 +128,9 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
} else {
end = 0;
}
epsilon = permuteIfNWC(epsilon);
for( long i = tsLength-1; i>= end; i--){
INDArray dldaCurrent = epsilon.get(all(), all(), point(i));
INDArray dldaCurrent = epsilon.get(all(), all(), point(i)).dup();
INDArray aCurrent = p.getFirst().get(all(), all(), point(i));
INDArray zCurrent = p.getSecond().get(all(), all(), point(i));
INDArray nCurrent = (hasLayerNorm() ? p.getThird().get(all(), all(), point(i)) : null);

@ -141,7 +145,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
//Recurrent weight gradients:
Nd4j.gemm(aCurrent, dldzNext, rwg, true, false, 1.0, 1.0);
}
INDArray dldzCurrent = a.backprop(zCurrent.dup(), dldaCurrent.dup()).getFirst();
INDArray dldzCurrent = a.backprop(zCurrent.dup(), dldaCurrent).getFirst();

//Handle masking
INDArray maskCol = null;

@ -200,6 +204,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
}

epsOut = backpropDropOutIfPresent(epsOut);
epsOut = permuteIfNWC(epsOut);
return new Pair<>(grad, epsOut);
}

@ -224,6 +229,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
applyDropOutIfNecessary(training, workspaceMgr);

INDArray input = this.input.castTo(dataType); //No-op if correct type
input = permuteIfNWC(input);
val m = input.size(0);
val tsLength = input.size(2);
val nOut = layerConf().getNOut();

@ -300,7 +306,12 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
Nd4j.getExecutioner().exec(new BroadcastMulOp(outZ, mask, outZ, 0, 2));
}
}

if (!forBackprop) {
out = permuteIfNWC(out);
outZ = permuteIfNWC(outZ);
outPreNorm = permuteIfNWC(outPreNorm);
recPreNorm = permuteIfNWC(recPreNorm);
}
return new Quad<>(out, outZ, outPreNorm, recPreNorm);
}
@ -2,6 +2,7 @@ package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.workspace.ArrayType;

@ -22,11 +23,11 @@ import org.nd4j.linalg.util.ArrayUtil;
*/
public class TimeDistributedLayer extends BaseWrapperLayer {

private final int timeAxis;
private RNNFormat rnnDataFormat;

public TimeDistributedLayer(Layer underlying, int timeAxis) {
public TimeDistributedLayer(Layer underlying, RNNFormat rnnDataFormat) {
super(underlying);
this.timeAxis = timeAxis;
this.rnnDataFormat = rnnDataFormat;
}

@ -56,7 +57,7 @@ public class TimeDistributedLayer extends BaseWrapperLayer {
protected INDArray reshape(INDArray array){
//Reshape the time axis to the minibatch axis
//For example, for RNN -> FF (dense time distributed): [mb, size, seqLen] -> [mb x seqLen, size]
int axis = timeAxis;
int axis = (rnnDataFormat == RNNFormat.NCW) ? 2 : 1;
if(axis < 0)
axis += array.rank();

@ -91,7 +92,7 @@ public class TimeDistributedLayer extends BaseWrapperLayer {

protected INDArray revertReshape(INDArray toRevert, long minibatch){

int axis = timeAxis;
int axis = (rnnDataFormat == RNNFormat.NCW)? 2 : 1;
if(axis < 0)
axis += (toRevert.rank()+1);
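The explicit time-axis argument is replaced by the layer's RNNFormat: the time axis is dimension 2 for NCW and dimension 1 for NWC, and that axis is folded into the minibatch axis before the wrapped feed-forward layer runs. A shape-only sketch of the NWC case (values are illustrative, not the layer's actual implementation):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class TimeDistributedShapeSketch {
        public static void main(String[] args) {
            // NWC activations [mb, seqLen, size]; time axis 1 is folded into the minibatch axis
            INDArray act = Nd4j.rand(new int[]{4, 10, 16});
            INDArray asFf = act.reshape(4 * 10, 16);   // [mb * seqLen, size] seen by the wrapped dense layer
            INDArray back = asFf.reshape(4, 10, 16);   // reverted to [mb, seqLen, size] afterwards
            System.out.println(java.util.Arrays.toString(back.shape()));
        }
    }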
@ -17,6 +17,13 @@
package org.deeplearning4j.util;

import lombok.val;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;

@ -233,6 +240,12 @@ public class TimeSeriesUtils {
return outReshape.reshape('f', in.size(0), in.size(1), in.size(2));
}

public static INDArray reverseTimeSeries(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType, RNNFormat dataFormat){
if (dataFormat == RNNFormat.NCW){
return reverseTimeSeries(in, workspaceMgr, arrayType);
}
return reverseTimeSeries(in.permute(0, 2, 1), workspaceMgr, arrayType).permute(0, 2, 1);
}
/**
* Reverse an input time series along the time dimension
*

@ -423,4 +436,25 @@ public class TimeSeriesUtils {

return new Pair<>(workspaceMgr.leverageTo(arrayType, out), fwdPassTimeSteps);
}

/**
* Get the {@link RNNFormat} from the RNN layer, accounting for the presence of wrapper layers like Bidirectional,
* LastTimeStep, etc
* @param layer Layer to get the RNNFormat from
*/
public static RNNFormat getFormatFromRnnLayer(Layer layer){
if(layer instanceof BaseRecurrentLayer){
return ((BaseRecurrentLayer) layer).getRnnDataFormat();
} else if(layer instanceof MaskZeroLayer){
return getFormatFromRnnLayer(((MaskZeroLayer) layer).getUnderlying());
} else if(layer instanceof Bidirectional){
return getFormatFromRnnLayer(((Bidirectional) layer).getFwd());
} else if(layer instanceof LastTimeStep){
return getFormatFromRnnLayer(((LastTimeStep) layer).getUnderlying());
} else if(layer instanceof TimeDistributed){
return ((TimeDistributed) layer).getRnnDataFormat();
} else {
throw new IllegalStateException("Unable to get RNNFormat from layer of type: " + layer);
}
}
}
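getFormatFromRnnLayer recursively unwraps the wrapper layer configurations until it reaches a layer that carries a concrete format. A minimal usage sketch (layer sizes are illustrative; the result is NCW unless the wrapped layer was built with NWC):

    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
    import org.deeplearning4j.util.TimeSeriesUtils;

    public class RnnFormatLookupSketch {
        public static void main(String[] args) {
            // The Bidirectional wrapper is unwrapped down to the underlying LSTM configuration
            Bidirectional bi = new Bidirectional(new LSTM.Builder().nIn(4).nOut(8).build());
            RNNFormat fmt = TimeSeriesUtils.getFormatFromRnnLayer(bi);
            System.out.println(fmt);   // NCW by default
        }
    }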