DL4J Fixes (#204)

* Fix issue with recently introduced exception handling system in MultiLayerNetwork/ComputationGraph
* Fix for SpaceToBatch layer
* #8133 DL4J SpaceToBatch gradient fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

branch: master
parent 54e320a255
commit 3f3b676ce5
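
For context on the first bullet: the change defers any Throwable raised during the forward or backward pass until after workspace cleanup, then rethrows it (see the if(t != null) blocks in the ComputationGraph and MultiLayerNetwork hunks below). A minimal sketch of that capture-and-rethrow pattern; Workspace and layerOps are illustrative stand-ins, not DL4J types:

    // Sketch only: 'Workspace' and 'layerOps' are hypothetical names for illustration.
    public class DeferredRethrowSketch {
        interface Workspace { void restore(); }

        static void forwardPass(Workspace ws, Runnable layerOps) {
            Throwable t = null;
            try {
                layerOps.run();          // the per-layer computation
            } catch (Throwable e) {
                t = e;                   // capture, but clean up before throwing
            }
            ws.restore();                // workspace cleanup always runs first

            if (t != null) {             // then surface the original failure
                if (t instanceof RuntimeException) {
                    throw (RuntimeException) t;   // preserve the original exception type
                }
                throw new RuntimeException("Error during neural network forward pass", t);
            }
        }
    }
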
@@ -22,12 +22,16 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
-import static org.junit.Assert.fail;
+import java.util.Map;
+
+import static org.junit.Assert.*;
 
 /**
  * A set of tests to ensure that useful exceptions are thrown on invalid input
@@ -267,23 +271,44 @@ public class TestInvalidInput extends BaseDL4JTest {
         //Idea: Using rnnTimeStep with a different number of examples between calls
         //(i.e., not calling reset between time steps)
 
-        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
-                        .layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build())
-                        .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
+        for(String layerType : new String[]{"simple", "lstm", "graves"}) {
 
-        MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.init();
+            Layer l;
+            switch (layerType){
+                case "simple":
+                    l = new SimpleRnn.Builder().nIn(5).nOut(5).build();
+                    break;
+                case "lstm":
+                    l = new LSTM.Builder().nIn(5).nOut(5).build();
+                    break;
+                case "graves":
+                    l = new GravesLSTM.Builder().nIn(5).nOut(5).build();
+                    break;
+                default:
+                    throw new RuntimeException();
+            }
 
-        net.rnnTimeStep(Nd4j.create(3, 5, 10));
+            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
+                    .layer(l)
+                    .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
 
-        try {
-            net.rnnTimeStep(Nd4j.create(5, 5, 10));
-            fail("Expected DL4JException");
-        } catch (DL4JException e) {
-            System.out.println("testInvalidRnnTimeStep(): " + e.getMessage());
-        } catch (Exception e) {
-            e.printStackTrace();
-            fail("Expected DL4JException");
+            MultiLayerNetwork net = new MultiLayerNetwork(conf);
+            net.init();
+
+            net.rnnTimeStep(Nd4j.create(3, 5, 10));
+
+            Map<String, INDArray> m = net.rnnGetPreviousState(0);
+            assertNotNull(m);
+            assertFalse(m.isEmpty());
+
+            try {
+                net.rnnTimeStep(Nd4j.create(5, 5, 10));
+                fail("Expected Exception - " + layerType);
+            } catch (Exception e) {
+                // e.printStackTrace();
+                String msg = e.getMessage();
+                assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
+            }
         }
     }
 }
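
The updated test asserts that rnnTimeStep() fails fast for SimpleRnn, LSTM and GravesLSTM when the minibatch size changes between calls. For reference, a hedged sketch of the correct usage the error message points to, assuming an initialized MultiLayerNetwork net with an RNN layer of nIn = 5:

    void timeStepTwoSequences(MultiLayerNetwork net) {
        net.rnnTimeStep(Nd4j.create(3, 5, 10));   // sequence 1: minibatch size 3
        net.rnnClearPreviousState();              // discard stored state between sequences
        net.rnnTimeStep(Nd4j.create(5, 5, 10));   // sequence 2: minibatch size 5, now valid
    }
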
@@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
+import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -343,6 +344,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
 
                     assertTrue(msg, gradOK);
 
+                    //Also check compgraph:
+                    ComputationGraph cg = net.toComputationGraph();
+                    gradOK = GradientCheckUtil.checkGradients(cg, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
+                            DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{input}, new INDArray[]{labels});
+                    assertTrue(msg + " - compgraph", gradOK);
+
                     TestUtils.testModelSerialization(net);
                 }
             }
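
The added lines convert the MultiLayerNetwork under test into an equivalent ComputationGraph so the same gradient check also exercises the graph code path, which is where the ComputationGraph batchSize() fix below applies. A minimal sketch of the conversion, with net assumed to be the initialized network from the surrounding test:

    ComputationGraph cg = net.toComputationGraph();
    // cg can then be passed to GradientCheckUtil.checkGradients(...) exactly as above.
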
@@ -2460,6 +2460,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
             }
             Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
 
+            if(t != null){
+                if(t instanceof RuntimeException){
+                    throw ((RuntimeException)t);
+                }
+                throw new RuntimeException("Error during neural network forward pass", t);
+            }
+
             if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) {
                 WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached");
             } else {
@@ -2780,6 +2787,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                 }
             }
             Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
+
+            if(t != null){
+                if(t instanceof RuntimeException){
+                    throw ((RuntimeException)t);
+                }
+                throw new RuntimeException("Error during neural network backpropagation calculation", t);
+            }
         }
 
         //Now, add the gradients in the order we need them in for flattening (same as params order)
@@ -3312,8 +3326,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
 
     @Override
     public int batchSize() {
+        //In 99+% of cases, the input and labels dimension 0 size should be identical
+        //The only real exceptions: space to batch, and batch to space layers
+        //In those cases, we should base it on the labels size, as this impacts gradient calculation
         // FIXME: int cast
-        return (int) inputs[0].size(0);
+        return labels == null || labels[0] == null ? (int) inputs[0].size(0) : (int)labels[0].size(0);
     }
 
     @Override
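
Why base batchSize() on the labels rather than the input: a space-to-batch layer multiplies the minibatch dimension by the product of its block sizes, so the output layer (and therefore the labels) can see a different dimension-0 size than the network input. A hedged worked example of the shape arithmetic, with illustrative numbers; the same reasoning applies to the MultiLayerNetwork.batchSize() change further below:

    int inputBatch = 8;                                      // inputs[0].size(0)
    int[] blocks = {2, 2};                                   // space-to-batch block sizes
    int labelsBatch = inputBatch * blocks[0] * blocks[1];    // 32 seen at the output layer
    // Gradient calculation scales by the batch size the output layer actually sees,
    // so labels[0].size(0) is the correct value when the two disagree.
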
@@ -70,12 +70,12 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
 
     private INDArray getBlocksArray() {
         int[] intBlocks = layerConf().getBlocks();
-        return Nd4j.create(new double[] {intBlocks[0], intBlocks[1]});
+        return Nd4j.createFromArray(intBlocks);
     }
 
     private INDArray getPaddingArray() {
         int[][] intPad = layerConf().getPadding();
-        return Nd4j.create( new double[][] { {intPad[0][0], intPad[0][1]}, {intPad[1][0], intPad[1][1]}});
+        return Nd4j.createFromArray(intPad);
     }
 
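
Behind the SpaceToBatch change: the old code rebuilt the blocks and padding arrays as floating-point NDArrays, whereas Nd4j.createFromArray keeps the integer type (and, for the padding, the original 2x2 shape) that the underlying space_to_batch operation expects. A hedged sketch of the difference:

    INDArray asFloating = Nd4j.create(new double[]{2, 2});       // floating-point values
    INDArray asInt      = Nd4j.createFromArray(new int[]{2, 2}); // integer dtype preserved
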
@@ -32,6 +32,7 @@ import org.deeplearning4j.nn.layers.BaseLayer;
 import org.deeplearning4j.nn.layers.mkldnn.MKLDNNConvHelper;
 import org.deeplearning4j.nn.workspace.ArrayType;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.activations.IActivation;
 import org.nd4j.linalg.activations.impl.ActivationSigmoid;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
@@ -188,12 +189,9 @@ public class LSTMHelpers {
         }
         //Input validation: check that if past state is provided, that it has same
         //These can be different if user forgets to call rnnClearPreviousState() between calls of rnnTimeStep
-        if (prevOutputActivations != null && prevOutputActivations.size(0) != input.size(0)) {
-            throw new DL4JInvalidInputException("Previous activations (stored state) number of examples = "
-                            + prevOutputActivations.size(0) + " but input array number of examples = " + input.size(0)
-                            + ". Possible cause: using rnnTimeStep() without calling"
-                            + " rnnClearPreviousState() between different sequences?");
-        }
+        Preconditions.checkState(prevOutputActivations == null || prevOutputActivations.size(0) == input.size(0),
+                "Invalid RNN previous state (last time step activations/initialization): rnnTimeStep with different minibatch size, or forgot to call rnnClearPreviousState between batches?" +
+                        " Previous step output = [batch, nIn] = %ndShape, current input = [batch, nIn, seqLength] = %ndShape", prevOutputActivations, input);
 
         //initialize prevOutputActivations to zeroes
         if (prevOutputActivations == null) {
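
The replacement uses ND4J's Preconditions, whose message template is only formatted when the check actually fails; the %ndShape placeholder renders an INDArray's shape into the message. A hedged minimal example, where prev and in are illustrative variable names:

    Preconditions.checkState(prev == null || prev.size(0) == in.size(0),
            "Minibatch size mismatch: previous state %ndShape vs current input %ndShape", prev, in);
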
@@ -217,6 +217,9 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
         assertInputSet(false);
         Preconditions.checkState(input.rank() == 3,
                 "3D input expected to RNN layer expected, got " + input.rank());
+        Preconditions.checkState(prevStepOut == null || prevStepOut.size(0) == input.size(0),
+                "Invalid RNN previous state (last time step activations/initialization): rnnTimeStep with different minibatch size, or forgot to call rnnClearPreviousState between batches?" +
+                        " Previous step output = [batch, nIn] = %ndShape, current input = [batch, nIn, seqLength] = %ndShape", prevStepOut, input);
 
         applyDropOutIfNecessary(training, workspaceMgr);
 
@@ -436,8 +436,11 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 
     @Override
     public int batchSize() {
+        //In 99+% of cases, the input and labels dimension 0 size should be identical
+        //The only real exceptions: space to batch, and batch to space layers
+        //In those cases, we should base it on the labels size, as this impacts gradient calculation
         // FIXME: int cast
-        return (int) input.size(0);
+        return labels == null ? (int) input.size(0) : (int)labels.size(0);
     }
 
     @Override
@@ -1362,6 +1365,13 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 
         Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
 
+        if(t != null){
+            if(t instanceof RuntimeException){
+                throw ((RuntimeException)t);
+            }
+            throw new RuntimeException("Error during neural network forward pass", t);
+        }
+
         if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) {
             WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached", true);
         } else {
@@ -2007,6 +2017,13 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
                 }
             }
             Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
+
+            if(t != null){
+                if(t instanceof RuntimeException){
+                    throw ((RuntimeException)t);
+                }
+                throw new RuntimeException("Error during neural network forward pass", t);
+            }
         }
 
         if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {