/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.multilayer;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.deeplearning4j.nn.layers.recurrent.GravesLSTM;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.*;

@Slf4j
public class MultiLayerTestRNN extends BaseDL4JTest {
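
//Checks parameter shapes and counts for a single GravesLSTM layer. GravesLSTM is an
//LSTM with peephole connections, so its recurrent weight matrix carries 3 extra
//peephole columns beyond the 4 gate blocks: [layerSize, 4*layerSize + 3].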
@Test
public void testGravesLSTMInit() {
int nIn = 8;
int nOut = 25;
int nHiddenUnits = 17;
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder()
.list().layer(0,
new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder()
.nIn(nIn).nOut(nHiddenUnits)
.activation(Activation.TANH).build())
.layer(1, new RnnOutputLayer.Builder(LossFunction.MSE).nIn(nHiddenUnits)
.nOut(nOut)
.activation(Activation.TANH).build())
.build();
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
//Ensure that we have the correct number of weights and biases, and that these have the correct shapes
Layer layer = network.getLayer(0);
assertTrue(layer instanceof GravesLSTM);
Map<String, INDArray> paramTable = layer.paramTable();
assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases
INDArray recurrentWeights = paramTable.get(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY);
assertArrayEquals(new long[] {nHiddenUnits, 4 * nHiddenUnits + 3}, recurrentWeights.shape()); //Should be shape: [layerSize,4*layerSize+3]
INDArray inputWeights = paramTable.get(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY);
assertArrayEquals(new long[] {nIn, 4 * nHiddenUnits}, inputWeights.shape()); //Should be shape: [nIn,4*layerSize]
INDArray biases = paramTable.get(GravesLSTMParamInitializer.BIAS_KEY);
assertArrayEquals(new long[] {1, 4 * nHiddenUnits}, biases.shape()); //Should be shape: [1,4*layerSize]
//Want forget gate biases to be initialized to > 0. See parameter initializer for details
INDArray forgetGateBiases =
biases.get(NDArrayIndex.point(0), NDArrayIndex.interval(nHiddenUnits, 2 * nHiddenUnits));
INDArray gt = forgetGateBiases.gt(0);
INDArray gtSum = gt.castTo(DataType.INT).sum(Integer.MAX_VALUE);
int count = gtSum.getInt(0);
assertEquals(nHiddenUnits, count);
val nParams = recurrentWeights.length() + inputWeights.length() + biases.length();
assertEquals(nParams, layer.numParams());
}
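
//Same shape and count checks as above, but for a stack of three GravesLSTM layers,
//where each layer's nIn must equal the previous layer's nOut.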
@Test
public void testGravesLSTMInitStacked() {
int nIn = 8;
int nOut = 25;
int[] nHiddenUnits = {17, 19, 23};
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(nHiddenUnits[0])
.activation(Activation.TANH).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nHiddenUnits[0]).nOut(nHiddenUnits[1])
.activation(Activation.TANH).build())
.layer(2, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nHiddenUnits[1]).nOut(nHiddenUnits[2])
.activation(Activation.TANH).build())
.layer(3, new RnnOutputLayer.Builder(LossFunction.MSE).nIn(nHiddenUnits[2]).nOut(nOut)
.activation(Activation.TANH).build())
.build();
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
//Ensure that we have the correct number of weights and biases, and that these have the correct shapes, for each layer
for (int i = 0; i < nHiddenUnits.length; i++) {
Layer layer = network.getLayer(i);
assertTrue(layer instanceof GravesLSTM);
Map<String, INDArray> paramTable = layer.paramTable();
assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases
int layerNIn = (i == 0 ? nIn : nHiddenUnits[i - 1]);
INDArray recurrentWeights = paramTable.get(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY);
assertArrayEquals(new long[] {nHiddenUnits[i], 4 * nHiddenUnits[i] + 3}, recurrentWeights.shape()); //Should be shape: [layerSize,4*layerSize+3]
INDArray inputWeights = paramTable.get(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY);
assertArrayEquals(new long[] {layerNIn, 4 * nHiddenUnits[i]}, inputWeights.shape()); //Should be shape: [nIn,4*layerSize]
INDArray biases = paramTable.get(GravesLSTMParamInitializer.BIAS_KEY);
assertArrayEquals(new long[] {1, 4 * nHiddenUnits[i]}, biases.shape()); //Should be shape: [1,4*layerSize]
//Want forget gate biases to be initialized to > 0. See parameter initializer for details
INDArray forgetGateBiases = biases.get(NDArrayIndex.point(0),
NDArrayIndex.interval(nHiddenUnits[i], 2 * nHiddenUnits[i]));
INDArray gt = forgetGateBiases.gt(0).castTo(DataType.INT);
int count = gt.sum(Integer.MAX_VALUE).getInt(0);
assertEquals(nHiddenUnits[i], count);
val nParams = recurrentWeights.length() + inputWeights.length() + biases.length();
assertEquals(nParams, layer.numParams());
}
}
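
//The first call to rnnTimeStep should match a full feedForward, and the stored state
//(last activation + memory cell per LSTM layer) should equal the final time step of
//the full forward pass. Also checks that state clear/set round-trips correctly.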
@Test
public void testRnnStateMethods() {
Nd4j.getRandom().setSeed(12345);
int timeSeriesLength = 6;
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder()
.list().layer(0,
new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder()
.nIn(5).nOut(7).activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7)
.nOut(8).activation(Activation.TANH)
.dist(new NormalDistribution(0,
0.5))
.build())
.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT)
.nIn(8).nOut(4)
.activation(Activation.SOFTMAX)
.dist(new NormalDistribution(0, 0.5)).build())
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
INDArray input = Nd4j.rand(new int[] {3, 5, timeSeriesLength});
List<INDArray> allOutputActivations = mln.feedForward(input, true);
INDArray outAct = allOutputActivations.get(3);
INDArray outRnnTimeStep = mln.rnnTimeStep(input);
assertEquals(outAct, outRnnTimeStep); //Should be identical here
Map<String, INDArray> currStateL0 = mln.rnnGetPreviousState(0);
Map<String, INDArray> currStateL1 = mln.rnnGetPreviousState(1);
assertEquals(2, currStateL0.size());
assertEquals(2, currStateL1.size());
INDArray lastActL0 = currStateL0.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
INDArray lastMemL0 = currStateL0.get(GravesLSTM.STATE_KEY_PREV_MEMCELL);
assertNotNull(lastActL0);
assertNotNull(lastMemL0);
INDArray lastActL1 = currStateL1.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
INDArray lastMemL1 = currStateL1.get(GravesLSTM.STATE_KEY_PREV_MEMCELL);
assertNotNull(lastActL1);
assertNotNull(lastMemL1);
INDArray expectedLastActL0 = allOutputActivations.get(1).tensorAlongDimension(timeSeriesLength - 1, 1, 0);
assertEquals(expectedLastActL0, lastActL0);
INDArray expectedLastActL1 = allOutputActivations.get(2).tensorAlongDimension(timeSeriesLength - 1, 1, 0);
assertEquals(expectedLastActL1, lastActL1);
//Check clearing and setting of state:
mln.rnnClearPreviousState();
assertTrue(mln.rnnGetPreviousState(0).isEmpty());
assertTrue(mln.rnnGetPreviousState(1).isEmpty());
mln.rnnSetPreviousState(0, currStateL0);
assertEquals(2, mln.rnnGetPreviousState(0).size());
mln.rnnSetPreviousState(1, currStateL1);
assertEquals(2, mln.rnnGetPreviousState(1).size());
}
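
//For GravesLSTM, LSTM and SimpleRnn: stepping through a 12-step series in chunks of
//1, 2, 3, 4, 6 or 12 time steps via rnnTimeStep must reproduce the single-pass
//feedForward output exactly, with the stored state tracking the last step processed.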
@Test
public void testRnnTimeStepLayers() {
for( int layerType=0; layerType<3; layerType++ ) {
org.deeplearning4j.nn.conf.layers.Layer l0;
org.deeplearning4j.nn.conf.layers.Layer l1;
String lastActKey;
if(layerType == 0){
l0 = new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build();
l1 = new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build();
lastActKey = GravesLSTM.STATE_KEY_PREV_ACTIVATION;
} else if(layerType == 1){
l0 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().nIn(5).nOut(7)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build();
l1 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().nIn(7).nOut(8)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build();
lastActKey = LSTM.STATE_KEY_PREV_ACTIVATION;
} else {
l0 = new org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn.Builder().nIn(5).nOut(7)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build();
l1 = new org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn.Builder().nIn(7).nOut(8)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build();
lastActKey = SimpleRnn.STATE_KEY_PREV_ACTIVATION;
}
log.info("Starting test for layer type: {}", l0.getClass().getSimpleName());
Nd4j.getRandom().setSeed(12345);
int timeSeriesLength = 12;
//4 layer network: 2 recurrent layers (of the type under test) + DenseLayer + RnnOutputLayer. Hence also tests preprocessors.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
.layer(0, l0)
.layer(1, l1)
.layer(2, new DenseLayer.Builder().nIn(8).nOut(9).activation(Activation.TANH)
.dist(
new NormalDistribution(0,
0.5))
.build())
.layer(3, new RnnOutputLayer.Builder(LossFunction.MCXENT)
.nIn(9).nOut(4).activation(Activation.SOFTMAX)
.dist(new NormalDistribution(0, 0.5))
.build())
.inputPreProcessor(2, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(3, new FeedForwardToRnnPreProcessor()).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
INDArray input = Nd4j.rand(new int[]{3, 5, timeSeriesLength});
List<INDArray> allOutputActivations = mln.feedForward(input, true);
INDArray fullOutL0 = allOutputActivations.get(1);
INDArray fullOutL1 = allOutputActivations.get(2);
INDArray fullOutL3 = allOutputActivations.get(4);
int[] inputLengths = {1, 2, 3, 4, 6, 12};
//Do steps of length 1, then of length 2, ..., 12
//Should get the same result regardless of step size; should be identical to standard forward pass
for (int i = 0; i < inputLengths.length; i++) {
int inLength = inputLengths[i];
int nSteps = timeSeriesLength / inLength; //each of length inLength
mln.rnnClearPreviousState();
mln.setInputMiniBatchSize(1); //Reset; should be set by rnnTimeStep method
for (int j = 0; j < nSteps; j++) {
int startTimeRange = j * inLength;
int endTimeRange = startTimeRange + inLength;
INDArray inputSubset;
if (inLength == 1) { //Workaround for an nd4j bug
val sizes = new long[]{input.size(0), input.size(1), 1};
inputSubset = Nd4j.create(sizes);
inputSubset.tensorAlongDimension(0, 1, 0).assign(input.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.point(startTimeRange)));
} else {
inputSubset = input.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(startTimeRange, endTimeRange));
}
if (inLength > 1)
assertEquals(inLength, inputSubset.size(2));
INDArray out = mln.rnnTimeStep(inputSubset);
INDArray expOutSubset;
if (inLength == 1) {
val sizes = new long[]{fullOutL3.size(0), fullOutL3.size(1), 1};
expOutSubset = Nd4j.create(DataType.FLOAT, sizes);
expOutSubset.tensorAlongDimension(0, 1, 0).assign(fullOutL3.get(NDArrayIndex.all(),
NDArrayIndex.all(), NDArrayIndex.point(startTimeRange)));
} else {
expOutSubset = fullOutL3.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(startTimeRange, endTimeRange));
}
assertEquals(expOutSubset, out);
Map<String, INDArray> currL0State = mln.rnnGetPreviousState(0);
Map<String, INDArray> currL1State = mln.rnnGetPreviousState(1);
INDArray lastActL0 = currL0State.get(lastActKey);
INDArray lastActL1 = currL1State.get(lastActKey);
INDArray expLastActL0 = fullOutL0.tensorAlongDimension(endTimeRange - 1, 1, 0);
INDArray expLastActL1 = fullOutL1.tensorAlongDimension(endTimeRange - 1, 1, 0);
assertEquals(expLastActL0, lastActL0);
assertEquals(expLastActL1, lastActL1);
}
}
}
}
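
//rnnTimeStep also accepts 2d input [miniBatch, nIn] for a single time step; the
//outputs should match the corresponding time slices of the 3d [miniBatch, nIn, T] case.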
@Test
public void testRnnTimeStep2dInput() {
Nd4j.getRandom().setSeed(12345);
int timeSeriesLength = 6;
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder()
.list().layer(0,
new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder()
.nIn(5).nOut(7).activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7)
.nOut(8).activation(Activation.TANH)
.dist(new NormalDistribution(0,
0.5))
.build())
.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT)
.nIn(8).nOut(4)
.activation(Activation.SOFTMAX)
.dist(new NormalDistribution(0, 0.5)).build())
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
INDArray input3d = Nd4j.rand(new long[] {3, 5, timeSeriesLength});
INDArray out3d = mln.rnnTimeStep(input3d);
assertArrayEquals(new long[] {3, 4, timeSeriesLength}, out3d.shape());
mln.rnnClearPreviousState();
for (int i = 0; i < timeSeriesLength; i++) {
INDArray input2d = input3d.tensorAlongDimension(i, 1, 0);
INDArray out2d = mln.rnnTimeStep(input2d);
assertArrayEquals(new long[] {3, 4}, out2d.shape());
INDArray expOut2d = out3d.tensorAlongDimension(i, 1, 0);
assertEquals(expOut2d, out2d);
}
//Check same but for input of size [3,5,1]. Expect [3,4,1] out
mln.rnnClearPreviousState();
for (int i = 0; i < timeSeriesLength; i++) {
INDArray temp = Nd4j.create(new int[] {3, 5, 1});
temp.tensorAlongDimension(0, 1, 0).assign(input3d.tensorAlongDimension(i, 1, 0));
INDArray out3dSlice = mln.rnnTimeStep(temp);
assertArrayEquals(new long[] {3, 4, 1}, out3dSlice.shape());
assertEquals(out3d.tensorAlongDimension(i, 1, 0), out3dSlice.tensorAlongDimension(0, 1, 0));
}
}
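
//When the TBPTT segment length equals the full time series length, truncated BPTT
//must produce exactly the same gradients and score as standard BPTT; the only
//difference is that the TBPTT network retains per-layer TBPTT state afterwards.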
@Test
public void testTruncatedBPTTVsBPTT() {
//Under some (limited) circumstances, we expect BPTT and truncated BPTT to be identical
//Specifically: TBPTT with a segment length equal to the full time series length
int timeSeriesLength = 12;
int miniBatchSize = 7;
int nIn = 5;
int nOut = 4;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.trainingWorkspaceMode(WorkspaceMode.NONE).inferenceWorkspaceMode(WorkspaceMode.NONE)
.list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
.activation(Activation.TANH)
.dist(
new NormalDistribution(0,
0.5))
.build())
.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT)
.nIn(8).nOut(nOut).activation(Activation.SOFTMAX)
.dist(new NormalDistribution(0, 0.5))
.build())
.build();
assertEquals(BackpropType.Standard, conf.getBackpropType());
MultiLayerConfiguration confTBPTT = new NeuralNetConfiguration.Builder().seed(12345)
.trainingWorkspaceMode(WorkspaceMode.NONE).inferenceWorkspaceMode(WorkspaceMode.NONE)
.list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
.activation(Activation.TANH)
.dist(
new NormalDistribution(0,
0.5))
.build())
.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT)
.nIn(8).nOut(nOut).activation(Activation.SOFTMAX)
.dist(new NormalDistribution(0, 0.5))
.build())
.backpropType(BackpropType.TruncatedBPTT).tBPTTBackwardLength(timeSeriesLength)
.tBPTTForwardLength(timeSeriesLength).build();
Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork mlnTBPTT = new MultiLayerNetwork(confTBPTT);
mlnTBPTT.init();
mlnTBPTT.clearTbpttState = false;
assertEquals(BackpropType.TruncatedBPTT, mlnTBPTT.getLayerWiseConfigurations().getBackpropType());
assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttFwdLength());
assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttBackLength());
INDArray inputData = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
INDArray labels = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength});
mln.setInput(inputData);
mln.setLabels(labels);
mlnTBPTT.setInput(inputData);
mlnTBPTT.setLabels(labels);
mln.computeGradientAndScore();
mlnTBPTT.computeGradientAndScore();
Pair<Gradient, Double> mlnPair = mln.gradientAndScore();
Pair<Gradient, Double> tbpttPair = mlnTBPTT.gradientAndScore();
assertEquals(mlnPair.getFirst().gradientForVariable(), tbpttPair.getFirst().gradientForVariable());
assertEquals(mlnPair.getSecond(), tbpttPair.getSecond(), 1e-8);
//Check states: expect stateMap to be empty but tBpttStateMap to not be
Map<String, INDArray> l0StateMLN = mln.rnnGetPreviousState(0);
Map<String, INDArray> l0StateTBPTT = mlnTBPTT.rnnGetPreviousState(0);
Map<String, INDArray> l1StateMLN = mln.rnnGetPreviousState(1);
Map<String, INDArray> l1StateTBPTT = mlnTBPTT.rnnGetPreviousState(1);
Map<String, INDArray> l0TBPTTStateMLN = ((BaseRecurrentLayer<?>) mln.getLayer(0)).rnnGetTBPTTState();
Map<String, INDArray> l0TBPTTStateTBPTT = ((BaseRecurrentLayer<?>) mlnTBPTT.getLayer(0)).rnnGetTBPTTState();
Map<String, INDArray> l1TBPTTStateMLN = ((BaseRecurrentLayer<?>) mln.getLayer(1)).rnnGetTBPTTState();
Map<String, INDArray> l1TBPTTStateTBPTT = ((BaseRecurrentLayer<?>) mlnTBPTT.getLayer(1)).rnnGetTBPTTState();
assertTrue(l0StateMLN.isEmpty());
assertTrue(l0StateTBPTT.isEmpty());
assertTrue(l1StateMLN.isEmpty());
assertTrue(l1StateTBPTT.isEmpty());
assertTrue(l0TBPTTStateMLN.isEmpty());
assertEquals(2, l0TBPTTStateTBPTT.size());
assertTrue(l1TBPTTStateMLN.isEmpty());
assertEquals(2, l1TBPTTStateTBPTT.size());
INDArray tbpttActL0 = l0TBPTTStateTBPTT.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
INDArray tbpttActL1 = l1TBPTTStateTBPTT.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
List<INDArray> activations = mln.feedForward(inputData, true);
INDArray l0Act = activations.get(1);
INDArray l1Act = activations.get(2);
INDArray expL0Act = l0Act.tensorAlongDimension(timeSeriesLength - 1, 1, 0);
INDArray expL1Act = l1Act.tensorAlongDimension(timeSeriesLength - 1, 1, 0);
assertEquals(tbpttActL0, expL0Act);
assertEquals(tbpttActL1, expL1Act);
}
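
//rnnActivateUsingStoredState on consecutive slices of a long series should match a
//single feedForward over the whole series, provided the stored TBPTT state is carried
//forward between slices via rnnSetPreviousState(rnnGetTBPTTState()).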
@Test
public void testRnnActivateUsingStoredState() {
int timeSeriesLength = 12;
int miniBatchSize = 7;
int nIn = 5;
int nOut = 4;
int nTimeSlices = 5;
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().seed(12345).list().layer(0,
new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7)
.nOut(8).activation(Activation.TANH)
.dist(new NormalDistribution(0,
0.5))
.build())
.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT)
.nIn(8).nOut(nOut)
.activation(Activation.SOFTMAX)
.dist(new NormalDistribution(0, 0.5)).build())
.build();
Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
INDArray inputLong = Nd4j.rand(new int[] {miniBatchSize, nIn, nTimeSlices * timeSeriesLength});
INDArray input = inputLong.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(0, timeSeriesLength));
List<INDArray> outStandard = mln.feedForward(input, true);
List<INDArray> outRnnAct = mln.rnnActivateUsingStoredState(input, true, true);
//As the initial state is all zeros, expect these to be the same
assertEquals(outStandard, outRnnAct);
//Furthermore, expect multiple calls to this function to be the same:
for (int i = 0; i < 3; i++) {
assertEquals(outStandard, mln.rnnActivateUsingStoredState(input, true, true));
}
List<INDArray> outStandardLong = mln.feedForward(inputLong, true);
BaseRecurrentLayer<?> l0 = ((BaseRecurrentLayer<?>) mln.getLayer(0));
BaseRecurrentLayer<?> l1 = ((BaseRecurrentLayer<?>) mln.getLayer(1));
for (int i = 0; i < nTimeSlices; i++) {
INDArray inSlice = inputLong.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(i * timeSeriesLength, (i + 1) * timeSeriesLength));
List<INDArray> outSlice = mln.rnnActivateUsingStoredState(inSlice, true, true);
List<INDArray> expOut = new ArrayList<>();
for (INDArray temp : outStandardLong) {
expOut.add(temp.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(i * timeSeriesLength, (i + 1) * timeSeriesLength)));
}
for (int j = 0; j < expOut.size(); j++) {
INDArray exp = expOut.get(j);
INDArray act = outSlice.get(j);
assertEquals(exp, act);
}
assertEquals(expOut, outSlice);
//Again, expect multiple calls to give the same output
for (int j = 0; j < 3; j++) {
outSlice = mln.rnnActivateUsingStoredState(inSlice, true, true);
assertEquals(expOut, outSlice);
}
l0.rnnSetPreviousState(l0.rnnGetTBPTTState());
l1.rnnSetPreviousState(l1.rnnGetTBPTTState());
}
}
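
//Smoke test: fitting on a series 20x longer than the TBPTT segment length should
//run end to end without exceptions.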
@Test
public void testTruncatedBPTTSimple() {
//Extremely simple test of the 'does it throw an exception' variety
int timeSeriesLength = 12;
int miniBatchSize = 7;
int nIn = 5;
int nOut = 4;
int nTimeSlices = 20;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
.activation(Activation.TANH)
.dist(
new NormalDistribution(0,
0.5))
.build())
.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT)
.nIn(8).nOut(nOut).activation(Activation.SOFTMAX)
.dist(new NormalDistribution(0, 0.5))
.build())
.backpropType(BackpropType.TruncatedBPTT)
.tBPTTBackwardLength(timeSeriesLength).tBPTTForwardLength(timeSeriesLength).build();
Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
INDArray inputLong = Nd4j.rand(new int[] {miniBatchSize, nIn, nTimeSlices * timeSeriesLength});
INDArray labelsLong = Nd4j.rand(new int[] {miniBatchSize, nOut, nTimeSlices * timeSeriesLength});
mln.fit(inputLong, labelsLong);
}
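
//As above, but with (all-ones) feature and label mask arrays supplied via the DataSet.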
@Test
public void testTruncatedBPTTWithMasking() {
//Extremely simple test of the 'does it throw an exception' variety
int timeSeriesLength = 100;
int tbpttLength = 10;
int miniBatchSize = 7;
int nIn = 5;
int nOut = 4;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 0.5)).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
.activation(Activation.TANH)
.dist(
new NormalDistribution(0,
0.5))
.build())
.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT)
.nIn(8).nOut(nOut).activation(Activation.SOFTMAX)
.dist(new NormalDistribution(0, 0.5))
.build())
.backpropType(BackpropType.TruncatedBPTT)
.tBPTTBackwardLength(tbpttLength).tBPTTForwardLength(tbpttLength).build();
Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
INDArray features = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
INDArray labels = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength});
INDArray maskArrayInput = Nd4j.ones(miniBatchSize, timeSeriesLength);
INDArray maskArrayOutput = Nd4j.ones(miniBatchSize, timeSeriesLength);
DataSet ds = new DataSet(features, labels, maskArrayInput, maskArrayOutput);
mln.fit(ds);
}
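
//rnnTimeStep should work with an input preprocessor on the first layer: here the 2d
//input [miniBatch, nIn] is reshaped to rank 3 by the FeedForwardToRnnPreProcessor.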
@Test
public void testRnnTimeStepWithPreprocessor() {
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10)
.nOut(10).activation(Activation.TANH).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10)
.nOut(10).activation(Activation.TANH).build())
.layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(10).nOut(10).build())
.inputPreProcessor(0, new FeedForwardToRnnPreProcessor())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(1, 10);
net.rnnTimeStep(in);
}
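
//Same preprocessor check as above, but for ComputationGraph.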
@Test
public void testRnnTimeStepWithPreprocessorGraph() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.graphBuilder().addInputs("in")
.addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
.activation(Activation.TANH).build(), "in")
.addLayer("1", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
.activation(Activation.TANH).build(), "0")
.addLayer("2", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1")
.setOutputs("2").inputPreProcessor("0", new FeedForwardToRnnPreProcessor())
.build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
INDArray in = Nd4j.rand(1, 10);
net.rnnTimeStep(in);
}
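
//A TBPTT length longer than the actual series should degrade gracefully to training
//over the full series; parameters should still update.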
@Test
public void testTBPTTLongerThanTS() {
//Extremely simple test of the 'does it throw an exception' variety
int timeSeriesLength = 20;
int tbpttLength = 1000;
int miniBatchSize = 7;
int nIn = 5;
int nOut = 4;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.weightInit(WeightInit.XAVIER).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
.activation(Activation.TANH).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
.activation(Activation.TANH).build())
.layer(2, new RnnOutputLayer.Builder(LossFunction.MSE).nIn(8).nOut(nOut)
.activation(Activation.IDENTITY).build())
.backpropType(BackpropType.TruncatedBPTT)
.tBPTTBackwardLength(tbpttLength).tBPTTForwardLength(tbpttLength).build();
Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
INDArray features = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
INDArray labels = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength});
INDArray maskArrayInput = Nd4j.ones(miniBatchSize, timeSeriesLength);
INDArray maskArrayOutput = Nd4j.ones(miniBatchSize, timeSeriesLength);
DataSet ds = new DataSet(features, labels, maskArrayInput, maskArrayOutput);
INDArray initialParams = mln.params().dup();
mln.fit(ds);
INDArray afterParams = mln.params();
assertNotEquals(initialParams, afterParams);
}
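
//Combining TBPTT with layers that collapse the time dimension (here GlobalPoolingLayer
//followed by a feed-forward OutputLayer) is invalid and should fail config validation.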
@Test
public void testInvalidTBPTT() {
int nIn = 8;
int nOut = 25;
int nHiddenUnits = 17;
try {
new NeuralNetConfiguration.Builder()
.list()
.layer(new org.deeplearning4j.nn.conf.layers.LSTM.Builder().nIn(nIn).nOut(nHiddenUnits).build())
.layer(new GlobalPoolingLayer())
.layer(new OutputLayer.Builder(LossFunction.MSE).nIn(nHiddenUnits)
.nOut(nOut)
.activation(Activation.TANH).build())
.backpropType(BackpropType.TruncatedBPTT)
.build();
fail("Exception expected");
} catch (IllegalStateException e) {
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
}
}
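
//rnnGetPreviousState/rnnSetPreviousState should see through wrapper layers such as
//FrozenLayer, for both MultiLayerNetwork and ComputationGraph.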
@Test
public void testWrapperLayerGetPreviousState(){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new FrozenLayer(new org.deeplearning4j.nn.conf.layers.LSTM.Builder()
.nIn(5).nOut(5).build()))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.create(1, 5, 2);
net.rnnTimeStep(in);
Map<String,INDArray> m = net.rnnGetPreviousState(0);
assertNotNull(m);
assertEquals(2, m.size()); //activation and cell state
net.rnnSetPreviousState(0, m);
ComputationGraph cg = net.toComputationGraph();
cg.rnnTimeStep(in);
m = cg.rnnGetPreviousState(0);
assertNotNull(m);
assertEquals(2, m.size()); //activation and cell state
cg.rnnSetPreviousState(0, m);
}
}