diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java index 377bb4ced..3ee1bcd6c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java @@ -165,35 +165,13 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { .weightInit(WeightInit.XAVIER) .updater(new Sgd(2)) .list() - .layer(0, - new DenseLayer.Builder() - .nIn(4) - .nOut(3) - .build() - ) - .layer(1, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder() - .nIn(3) - .nOut(4) - .build() - ) - ) - .layer(2, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder() - .nIn(4) - .nOut(2) - .build() - ) - ).layer(3, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new OutputLayer.Builder(LossFunctions.LossFunction.MSE) - .nIn(2) - .nOut(1) - .build() - ) - ) + .layer(new DenseLayer.Builder().nIn(4).nOut(3).build()) + .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( + new DenseLayer.Builder().nIn(3).nOut(4).build())) + .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( + new DenseLayer.Builder().nIn(4).nOut(2).build())) + .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( + new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())) .build(); MultiLayerNetwork network = new MultiLayerNetwork(conf1); @@ -238,70 +216,18 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { .seed(12345) .graphBuilder() .addInputs("input") - .addLayer(initialLayer, - new DenseLayer.Builder() - .nIn(4) - .nOut(4) - .build(), - "input" - ) - .addLayer(frozenBranchUnfrozenLayer0, - new DenseLayer.Builder() - .nIn(4) - .nOut(3) - .build(), - initialLayer - ) - .addLayer(frozenBranchFrozenLayer1, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder() - .nIn(3) - .nOut(4) - .build() - ), - frozenBranchUnfrozenLayer0 - ) - .addLayer(frozenBranchFrozenLayer2, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder() - .nIn(4) - .nOut(2) - .build() - ), - frozenBranchFrozenLayer1 - ) - .addLayer(unfrozenLayer0, - new DenseLayer.Builder() - .nIn(4) - .nOut(4) - .build(), - initialLayer - ) - .addLayer(unfrozenLayer1, - new DenseLayer.Builder() - .nIn(4) - .nOut(2) - .build(), - unfrozenLayer0 - ) - .addLayer(unfrozenBranch2, - new DenseLayer.Builder() - .nIn(2) - .nOut(1) - .build(), - unfrozenLayer1 - ) - .addVertex("merge", - new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2) - .addLayer(frozenBranchOutput, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new OutputLayer.Builder(LossFunctions.LossFunction.MSE) - .nIn(3) - .nOut(1) - .build() - ), - "merge" - ) + .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input") + .addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer) + .addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( + new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0) + 
.addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( + new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1) + .addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer) + .addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0) + .addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1) + .addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2) + .addLayer(frozenBranchOutput,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( + new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge") .setOutputs(frozenBranchOutput) .build(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java index 5868c0287..d6cab0273 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java @@ -35,6 +35,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.convolution.Convolution; +import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -170,6 +171,8 @@ public class ConvolutionLayer extends BaseLayer ret = null; try { ret = helper.backpropGradient(input, epsilon, k, n, alpha, beta, workspaceMgr); + } catch (ND4JOpProfilerException e){ + throw e; //NaN panic etc for debugging } catch (Throwable t){ if(t.getMessage().contains("Failed to allocate")){ //This is a memory exception - don't fallback to built-in implementation @@ -206,6 +209,8 @@ public class LocalResponseNormalization INDArray activations = null; try { activations = helper.activate(input, training, k, n, alpha, beta, workspaceMgr); + } catch (ND4JOpProfilerException e){ + throw e; //NaN panic etc for debugging } catch (Throwable t){ if(t.getMessage().contains("Failed to allocate")){ //This is a memory exception - don't fallback to built-in implementation diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java new file mode 100644 index 000000000..58ba7fffc --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java @@ -0,0 +1,257 @@ +package org.nd4j.autodiff.listeners.debugging; + +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.ScalarOp; + +import java.util.Arrays; + +/** + * A listener that logs operation execution 
for debugging purposes.
+ * Three modes are supported:

+ * OPS_ONLY: Only the operation class names are printed. For example:
+ * {@code (iter=0,epoch=0,op=1) org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp}
+ * SHAPES_ONLY: Print the operation class, the shape info for the input/output arrays, and any op arguments (iArgs, bArgs, tArgs). For example:
+ * {@code
+ * (iter=1,epoch=0,op=3) org.nd4j.linalg.api.ops.impl.loss.LogLoss
+ * 	iArgs=[3]
+ * 	tArgs=[1.0E-7]
+ * 	Input[0]=Rank: 2, DataType: FLOAT, Offset: 0, Order: c, Shape: [1,2],  Stride: [1,1]
+ * 	Input[1]=Rank: 0, DataType: FLOAT, Offset: 0, Order: c, Shape: [],  Stride: []
+ * 	Input[2]=Rank: 2, DataType: FLOAT, Offset: 0, Order: c, Shape: [1,2],  Stride: [1,1]
+ * 	Outputs[0]=Rank: 0, DataType: FLOAT, Offset: 0, Order: c, Shape: [],  Stride: []
+ * }
+ * 
+ * REPRODUCE: Print runnable Java code that should reproduce that op execution (other than perhaps exact input/output strides). For example:
+ * {@code
+ * (iter=2,epoch=0,op=1) org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp
+ * DynamicCustomOp op = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp();
+ * INDArray[] inputs = new INDArray[2];
+ * inputs[0] = Nd4j.createFromArray(1.5253239f, 0.8733858f).reshape(1, 2);
+ * inputs[1] = Nd4j.createFromArray(0.483428f, 0.86025196f).reshape(1, 2);
+ * op.addInputArgument(inputs);
+ * INDArray[] outputs = new INDArray[1];
+ * outputs[0] = Nd4j.createFromArray(2.012087f, 1.7303026f).reshape(1, 2);
+ * op.addOutputArgument(outputs);
+ * Nd4j.exec(op);
+ * }
+ * 
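+ *
+ * A minimal usage sketch (based on the ExecDebuggingListenerTest added in this change; the SameDiff graph, its TrainingConfig and the DataSet are assumed to already exist):
+ * {@code
+ * SameDiff sd = ...;                                                            //existing SameDiff graph with a TrainingConfig set
+ * sd.setListeners(new ExecDebuggingListener(PrintMode.REPRODUCE, -1, true));    //maxIterations <= 0 means "print for all iterations"
+ * sd.fit(dataSet);                                                              //each op execution is printed to System.out during fitting
+ * }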
+ * + * @author Alex Black + */ +public class ExecDebuggingListener extends BaseListener { + + public enum PrintMode {OPS_ONLY, SHAPES_ONLY, REPRODUCE} + + private final PrintMode printMode; + private final int maxIterations; + private final boolean logIter; + + private long printIterations = 0; + private int lastIter = -1; + private int stepThisIter = 0; + + /** + * @param printMode Print mode, see {@link PrintMode} + * @param maxIterations Maximum number of iterations to print. <= 0 for "all iterations" + * @param logIter If true: prefix iteration/epoch, such as "(iter=1,epoch=0,op=3)" to the output + */ + public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter){ + this.printMode = printMode; + this.maxIterations = maxIterations; + this.logIter = logIter; + } + + @Override + public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) { + if(lastIter != at.iteration()){ + lastIter = at.iteration(); + stepThisIter = 0; + printIterations++; + } + + if(maxIterations > 0 && printIterations > maxIterations){ + return; + } + + StringBuilder sb = new StringBuilder(); + if(logIter){ + sb.append("(iter=").append(at.iteration()) + .append(",epoch=").append(at.epoch()) + .append(","); + } + sb.append("op=").append(stepThisIter++) + .append(logIter ? ") " : " - "); + + DifferentialFunction df = op.getOp(); + sb.append(op.getOp().getClass().getName()); + CustomOp co = df instanceof CustomOp ? (CustomOp) df : null; + Op lOp = df instanceof Op ? (Op) df : null; + if(printMode == PrintMode.OPS_ONLY){ + sb.append("\n"); + } else if(printMode == PrintMode.SHAPES_ONLY){ + if(co != null){ + if(co.iArgs() != null && co.iArgs().length > 0) { + sb.append("\n\tiArgs=").append(Arrays.toString(co.iArgs())); + } + if(co.bArgs() != null && co.bArgs().length > 0) { + sb.append("\n\tbArgs=").append(Arrays.toString(co.bArgs())); + } + if(co.tArgs() != null && co.tArgs().length > 0) { + sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs())); + } + INDArray[] inputs = co.inputArguments(); + INDArray[] outputs = co.outputArguments(); + if(inputs != null ) { + for (int i = 0; i < inputs.length; i++) { + sb.append("\n\tInput[").append(i).append("]=").append(inputs[i].shapeInfoToString()); + } + } + if(outputs != null ) { + for (int i = 0; i < outputs.length; i++) { + sb.append("\n\tOutputs[").append(i).append("]=").append(outputs[i].shapeInfoToString()); + } + } + } else { + if(lOp.x() != null) { + sb.append("\n\tx: ").append(lOp.x().shapeInfoToString()); + } + if(lOp.y() != null) { + sb.append("\n\ty: ").append(lOp.y().shapeInfoToString()); + } + if(lOp.z() != null) { + sb.append("\n\tz: ").append(lOp.z().shapeInfoToString()); + } + if(lOp instanceof ScalarOp){ + INDArray scalar = ((ScalarOp)lOp).scalar(); + if(scalar != null){ + sb.append("\n\tscalar: ").append(scalar.shapeInfoToString()); + } + } + } + sb.append("\n"); + } else if(printMode == PrintMode.REPRODUCE){ + sb.append("\n"); + if(co != null){ + sb.append("DynamicCustomOp op = new ").append(co.getClass().getName()).append("();\n"); + if(co.iArgs() != null && co.iArgs().length > 0 ){ + sb.append("op.addIArgument(").append(Arrays.toString(co.iArgs()).replaceAll("[\\[\\]]", "")).append(");\n"); + } + if(co.bArgs() != null && co.bArgs().length > 0 ){ + sb.append("op.addBArgument(").append(Arrays.toString(co.bArgs()).replaceAll("[\\[\\]]", "")).append(");\n"); + } + if(co.tArgs() != null && co.tArgs().length > 0 ){ + sb.append("op.addTArgument(").append(Arrays.toString(co.tArgs()).replaceAll("[\\[\\]]", 
"")).append(");\n"); + } + INDArray[] inputs = co.inputArguments(); + INDArray[] outputs = co.outputArguments(); + if(inputs != null ) { + sb.append("INDArray[] inputs = new INDArray[").append(inputs.length).append("];\n"); + for (int i = 0; i < inputs.length; i++) { + sb.append("inputs[").append(i).append("] = "); + sb.append(createString(inputs[i])) + .append(";\n"); + } + sb.append("op.addInputArgument(inputs);\n"); + } + if(outputs != null ) { + sb.append("INDArray[] outputs = new INDArray[").append(outputs.length).append("];\n"); + for (int i = 0; i < outputs.length; i++) { + sb.append("outputs[").append(i).append("] = "); + sb.append(createString(outputs[i])) + .append(";\n"); + } + sb.append("op.addOutputArgument(outputs);\n"); + } + } else { + sb.append("Op op = new ").append(op.getClass().getName()).append("();\n"); + if(lOp.x() != null) { + sb.append("op.setX(").append(createString(lOp.x())).append(");\n"); + } + if(lOp.y() != null) { + sb.append("op.setY(").append(createString(lOp.y())).append(");\n"); + } + if(lOp.z() != null) { + sb.append("op.setZ").append(createString(lOp.z())).append(");\n"); + } + if(lOp instanceof ScalarOp){ + INDArray scalar = ((ScalarOp)lOp).scalar(); + if(scalar != null){ + sb.append("((ScalarOp)op).setScalar(").append(createString(scalar)).append(");\n"); + } + } + } + sb.append("Nd4j.exec(op);\n"); + } + + System.out.print(sb.toString()); + } + + private static String createString(INDArray arr){ + StringBuilder sb = new StringBuilder(); + + if(arr.isEmpty()){ + sb.append("Nd4j.empty(DataType.").append(arr.dataType()).append(");"); + } else { + sb.append("Nd4j.createFromArray("); + + DataType dt = arr.dataType(); + switch (dt){ + case DOUBLE: + double[] dArr = arr.dup().data().asDouble(); + sb.append(Arrays.toString(dArr).replaceAll("[\\[\\]]", "")); + break; + case FLOAT: + case HALF: + case BFLOAT16: + float[] fArr = arr.dup().data().asFloat(); + sb.append(Arrays.toString(fArr) + .replaceAll(",", "f,") + .replaceAll("]", "f") + .replaceAll("[\\[\\]]", "")); + break; + case LONG: + case UINT32: + case UINT64: + long[] lArr = arr.dup().data().asLong(); + sb.append(Arrays.toString(lArr) + .replaceAll(",", "L,") + .replaceAll("]", "L") + .replaceAll("[\\[\\]]", "")); + break; + case INT: + case SHORT: + case UBYTE: + case BYTE: + case UINT16: + case BOOL: + int[] iArr = arr.dup().data().asInt(); + sb.append(Arrays.toString(iArr).replaceAll("[\\[\\]]", "")); + break; + case UTF8: + break; + case COMPRESSED: + case UNKNOWN: + break; + } + + sb.append(").reshape(").append(Arrays.toString(arr.shape()).replaceAll("[\\[\\]]", "")) + .append(")"); + + if(dt == DataType.HALF || dt == DataType.BFLOAT16 || dt == DataType.UINT32 || dt == DataType.UINT64 || + dt == DataType.SHORT || dt == DataType.UBYTE || dt == DataType.BYTE || dt == DataType.UINT16 || dt == DataType.BOOL){ + sb.append(".cast(DataType.").append(arr.dataType()).append(")"); + } + } + + return sb.toString(); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index ec37deeb8..1231fcb37 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -235,7 +235,18 @@ public class 
DifferentialFunctionClassHolder { //log.debug("Missing " + set.size() + " ops!"); countTotalTfOps = tensorflowOpDescriptors.size(); - countTotalMappedOps = nodeConverters.size(); + + //Work out total number of TF ops mapped + Set tfMappedOps = new HashSet<>(); + for(DifferentialFunction df : nodeConverters.values()){ + try{ + String[] tfNames = df.tensorflowNames(); + Collections.addAll(tfMappedOps, tfNames); + } catch (NoOpNameFoundException e){ + //Ignore + } + } + countTotalMappedOps = tfMappedOps.size(); //Get custom ops - map from hash to class Map descriptorMap = Nd4j.getExecutioner().getCustomOperations(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java new file mode 100644 index 000000000..35003ed1f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java @@ -0,0 +1,61 @@ +package org.nd4j.autodiff.samediff.listeners; + +import org.junit.Test; +import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.learning.config.Adam; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class ExecDebuggingListenerTest extends BaseNd4jTest { + + public ExecDebuggingListenerTest(Nd4jBackend backend) { + super(backend); + } + + @Test + public void testExecDebugListener(){ + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2)); + SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b)); + SDVariable loss = sd.loss.logLoss("loss", label, sm); + + INDArray i = Nd4j.rand(DataType.FLOAT, 1, 3); + INDArray l = Nd4j.rand(DataType.FLOAT, 1, 2); + + sd.setTrainingConfig(TrainingConfig.builder() + .dataSetFeatureMapping("in") + .dataSetLabelMapping("label") + .updater(new Adam(0.001)) + .build()); + + for(ExecDebuggingListener.PrintMode pm : ExecDebuggingListener.PrintMode.values()){ + sd.setListeners(new ExecDebuggingListener(pm, -1, true)); +// sd.output(m, "softmax"); + sd.fit(new DataSet(i, l)); + + System.out.println("\n\n\n"); + } + + } + + + @Override + public char ordering() { + return 'c'; + } +}