DL4J Fixes (#204)

* Fix issue with recently introduced exception handling system in MultiLayerNetwork/ComputationGraph

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for SpaceToBatch layer

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8133 DL4J SpaceToBatch gradient fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-30 23:00:53 +10:00 committed by GitHub
parent 54e320a255
commit 3f3b676ce5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 92 additions and 25 deletions

View File

@ -22,12 +22,16 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.fail; import java.util.Map;
import static org.junit.Assert.*;
/** /**
* A set of tests to ensure that useful exceptions are thrown on invalid input * A set of tests to ensure that useful exceptions are thrown on invalid input
@ -267,23 +271,44 @@ public class TestInvalidInput extends BaseDL4JTest {
//Idea: Using rnnTimeStep with a different number of examples between calls //Idea: Using rnnTimeStep with a different number of examples between calls
//(i.e., not calling reset between time steps) //(i.e., not calling reset between time steps)
for(String layerType : new String[]{"simple", "lstm", "graves"}) {
Layer l;
switch (layerType){
case "simple":
l = new SimpleRnn.Builder().nIn(5).nOut(5).build();
break;
case "lstm":
l = new LSTM.Builder().nIn(5).nOut(5).build();
break;
case "graves":
l = new GravesLSTM.Builder().nIn(5).nOut(5).build();
break;
default:
throw new RuntimeException();
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
.layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build()) .layer(l)
.layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build(); .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
net.rnnTimeStep(Nd4j.create(3, 5, 10)); net.rnnTimeStep(Nd4j.create(3, 5, 10));
Map<String, INDArray> m = net.rnnGetPreviousState(0);
assertNotNull(m);
assertFalse(m.isEmpty());
try { try {
net.rnnTimeStep(Nd4j.create(5, 5, 10)); net.rnnTimeStep(Nd4j.create(5, 5, 10));
fail("Expected DL4JException"); fail("Expected Exception - " + layerType);
} catch (DL4JException e) {
System.out.println("testInvalidRnnTimeStep(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); // e.printStackTrace();
fail("Expected DL4JException"); String msg = e.getMessage();
assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
}
} }
} }
} }

View File

@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@ -343,6 +344,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
assertTrue(msg, gradOK); assertTrue(msg, gradOK);
//Also check compgraph:
ComputationGraph cg = net.toComputationGraph();
gradOK = GradientCheckUtil.checkGradients(cg, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{input}, new INDArray[]{labels});
assertTrue(msg + " - compgraph", gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
} }

View File

@ -2460,6 +2460,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
if(t != null){
if(t instanceof RuntimeException){
throw ((RuntimeException)t);
}
throw new RuntimeException("Error during neural network forward pass", t);
}
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) { if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) {
WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached"); WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached");
} else { } else {
@ -2780,6 +2787,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
} }
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
if(t != null){
if(t instanceof RuntimeException){
throw ((RuntimeException)t);
}
throw new RuntimeException("Error during neural network backpropagation calculation", t);
}
} }
//Now, add the gradients in the order we need them in for flattening (same as params order) //Now, add the gradients in the order we need them in for flattening (same as params order)
@ -3312,8 +3326,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
@Override @Override
public int batchSize() { public int batchSize() {
//In 99+% of cases, the input and labels dimension 0 size should be identical
//The only real exceptions: space to batch, and batch to space layers
//In those cases, we should base it on the labels size, as this impacts gradient calculation
// FIXME: int cast // FIXME: int cast
return (int) inputs[0].size(0); return labels == null || labels[0] == null ? (int) inputs[0].size(0) : (int)labels[0].size(0);
} }
@Override @Override

View File

@ -70,12 +70,12 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
private INDArray getBlocksArray() { private INDArray getBlocksArray() {
int[] intBlocks = layerConf().getBlocks(); int[] intBlocks = layerConf().getBlocks();
return Nd4j.create(new double[] {intBlocks[0], intBlocks[1]}); return Nd4j.createFromArray(intBlocks);
} }
private INDArray getPaddingArray() { private INDArray getPaddingArray() {
int[][] intPad = layerConf().getPadding(); int[][] intPad = layerConf().getPadding();
return Nd4j.create( new double[][] { {intPad[0][0], intPad[0][1]}, {intPad[1][0], intPad[1][1]}}); return Nd4j.createFromArray(intPad);
} }

View File

@ -32,6 +32,7 @@ import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNConvHelper; import org.deeplearning4j.nn.layers.mkldnn.MKLDNNConvHelper;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -188,12 +189,9 @@ public class LSTMHelpers {
} }
//Input validation: check that if past state is provided, that it has same //Input validation: check that if past state is provided, that it has same
//These can be different if user forgets to call rnnClearPreviousState() between calls of rnnTimeStep //These can be different if user forgets to call rnnClearPreviousState() between calls of rnnTimeStep
if (prevOutputActivations != null && prevOutputActivations.size(0) != input.size(0)) { Preconditions.checkState(prevOutputActivations == null || prevOutputActivations.size(0) == input.size(0),
throw new DL4JInvalidInputException("Previous activations (stored state) number of examples = " "Invalid RNN previous state (last time step activations/initialization): rnnTimeStep with different minibatch size, or forgot to call rnnClearPreviousState between batches?" +
+ prevOutputActivations.size(0) + " but input array number of examples = " + input.size(0) " Previous step output = [batch, nIn] = %ndShape, current input = [batch, nIn, seqLength] = %ndShape", prevOutputActivations, input);
+ ". Possible cause: using rnnTimeStep() without calling"
+ " rnnClearPreviousState() between different sequences?");
}
//initialize prevOutputActivations to zeroes //initialize prevOutputActivations to zeroes
if (prevOutputActivations == null) { if (prevOutputActivations == null) {

View File

@ -217,6 +217,9 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
assertInputSet(false); assertInputSet(false);
Preconditions.checkState(input.rank() == 3, Preconditions.checkState(input.rank() == 3,
"3D input expected to RNN layer expected, got " + input.rank()); "3D input expected to RNN layer expected, got " + input.rank());
Preconditions.checkState(prevStepOut == null || prevStepOut.size(0) == input.size(0),
"Invalid RNN previous state (last time step activations/initialization): rnnTimeStep with different minibatch size, or forgot to call rnnClearPreviousState between batches?" +
" Previous step output = [batch, nIn] = %ndShape, current input = [batch, nIn, seqLength] = %ndShape", prevStepOut, input);
applyDropOutIfNecessary(training, workspaceMgr); applyDropOutIfNecessary(training, workspaceMgr);

View File

@ -436,8 +436,11 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
@Override @Override
public int batchSize() { public int batchSize() {
//In 99+% of cases, the input and labels dimension 0 size should be identical
//The only real exceptions: space to batch, and batch to space layers
//In those cases, we should base it on the labels size, as this impacts gradient calculation
// FIXME: int cast // FIXME: int cast
return (int) input.size(0); return labels == null ? (int) input.size(0) : (int)labels.size(0);
} }
@Override @Override
@ -1362,6 +1365,13 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
if(t != null){
if(t instanceof RuntimeException){
throw ((RuntimeException)t);
}
throw new RuntimeException("Error during neural network forward pass", t);
}
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) { if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) {
WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached", true); WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached", true);
} else { } else {
@ -2007,6 +2017,13 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
} }
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
if(t != null){
if(t instanceof RuntimeException){
throw ((RuntimeException)t);
}
throw new RuntimeException("Error during neural network forward pass", t);
}
} }
if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {