diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java index 3ee1bcd6c..200f55071 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java @@ -157,8 +157,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { @Test public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() { - - DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345)); + Nd4j.getRandom().setSeed(12345); + DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() .seed(12345) @@ -194,8 +194,9 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { @Test public void testComputationGraphFrozenLayerParamsAfterBackprop() { + Nd4j.getRandom().setSeed(12345); - DataSet randomData = new DataSet(Nd4j.rand(100, 4,12345), Nd4j.rand(100, 1, 12345)); + DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); String frozenBranchName = "B1-"; String unfrozenBranchName = "B2-"; @@ -254,43 +255,18 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { */ @Test public void testFrozenLayerVsSgd() { - DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345)); + Nd4j.getRandom().setSeed(12345); + DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder() .seed(12345) .weightInit(WeightInit.XAVIER) .updater(new Sgd(2)) .list() - .layer(0, - new DenseLayer.Builder() - .nIn(4) - .nOut(3) - .build() - ) - .layer(1, - new DenseLayer.Builder() - .updater(new Sgd(0.0)) - .biasUpdater(new Sgd(0.0)) - .nIn(3) - .nOut(4) - .build() - ).layer(2, - new DenseLayer.Builder() - .updater(new Sgd(0.0)) - .biasUpdater(new Sgd(0.0)) - .nIn(4) - .nOut(2) - .build() - - ).layer(3, - new OutputLayer.Builder(LossFunctions.LossFunction.MSE) - .updater(new Sgd(0.0)) - .biasUpdater(new Sgd(0.0)) - .activation(Activation.TANH) - .nIn(2) - .nOut(1) - .build() - ) + .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build()) + .layer(1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build()) + .layer(2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build()) + .layer(3,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build()) .build(); MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder() @@ -298,36 +274,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { .weightInit(WeightInit.XAVIER) .updater(new Sgd(2)) .list() - .layer(0, - new DenseLayer.Builder() - .nIn(4) - .nOut(3) - .build() - ) - .layer(1, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder() - .nIn(3) - .nOut(4) - .build() - ) - ) - .layer(2, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder() - .nIn(4) - .nOut(2) - .build() - ) - ).layer(3, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new OutputLayer.Builder(LossFunctions.LossFunction.MSE) - .activation(Activation.TANH) - .nIn(2) - .nOut(1) - .build() - ) - ) + .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build()) + .layer(1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build())) + .layer(2,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build())) + .layer(3,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())) .build(); MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen); frozenNetwork.init(); @@ -359,8 +309,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { @Test public void testComputationGraphVsSgd() { - - DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345)); + Nd4j.getRandom().setSeed(12345); + DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); String frozenBranchName = "B1-"; String unfrozenBranchName = "B2-"; @@ -381,71 +331,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { .seed(12345) .graphBuilder() .addInputs("input") - .addLayer(initialLayer, - new DenseLayer.Builder() - .nIn(4) - .nOut(4) - .build(), - "input" - ) - .addLayer(frozenBranchUnfrozenLayer0, - new DenseLayer.Builder() - .nIn(4) - .nOut(3) - .build(), - initialLayer - ) - .addLayer(frozenBranchFrozenLayer1, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder() - .nIn(3) - .nOut(4) - .build() - ), - frozenBranchUnfrozenLayer0 - ) + .addLayer(initialLayer,new DenseLayer.Builder().nIn(4).nOut(4).build(),"input") + .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer) + .addLayer(frozenBranchFrozenLayer1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( + new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0) .addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder() - .nIn(4) - .nOut(2) - .build() - ), - frozenBranchFrozenLayer1 - ) - .addLayer(unfrozenLayer0, - new DenseLayer.Builder() - .nIn(4) - .nOut(4) - .build(), - initialLayer - ) - .addLayer(unfrozenLayer1, - new DenseLayer.Builder() - .nIn(4) - .nOut(2) - .build(), - unfrozenLayer0 - ) - .addLayer(unfrozenBranch2, - new DenseLayer.Builder() - .nIn(2) - .nOut(1) - .build(), - unfrozenLayer1 - ) - .addVertex("merge", - new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2) - .addLayer(frozenBranchOutput, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new OutputLayer.Builder(LossFunctions.LossFunction.MSE) - .activation(Activation.TANH) - .nIn(3) - .nOut(1) - .build() - ), - "merge" - ) + new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1) + .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer) + .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0) + .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1) + .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2) + .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( + new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge") .setOutputs(frozenBranchOutput) .build(); @@ -454,73 +352,15 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { .seed(12345) .graphBuilder() .addInputs("input") - .addLayer(initialLayer, - new DenseLayer.Builder() - .nIn(4) - .nOut(4) - .build(), - "input" - ) - .addLayer(frozenBranchUnfrozenLayer0, - new DenseLayer.Builder() - .nIn(4) - .nOut(3) - .build(), - initialLayer - ) - .addLayer(frozenBranchFrozenLayer1, - new DenseLayer.Builder() - .updater(new Sgd(0.0)) - .biasUpdater(new Sgd(0.0)) - .nIn(3) - .nOut(4) - .build(), - frozenBranchUnfrozenLayer0 - ) - .addLayer(frozenBranchFrozenLayer2, - new DenseLayer.Builder() - .updater(new Sgd(0.0)) - .biasUpdater(new Sgd(0.0)) - .nIn(4) - .nOut(2) - .build() - , - frozenBranchFrozenLayer1 - ) - .addLayer(unfrozenLayer0, - new DenseLayer.Builder() - .nIn(4) - .nOut(4) - .build(), - initialLayer - ) - .addLayer(unfrozenLayer1, - new DenseLayer.Builder() - .nIn(4) - .nOut(2) - .build(), - unfrozenLayer0 - ) - .addLayer(unfrozenBranch2, - new DenseLayer.Builder() - .nIn(2) - .nOut(1) - .build(), - unfrozenLayer1 - ) - .addVertex("merge", - new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2) - .addLayer(frozenBranchOutput, - new OutputLayer.Builder(LossFunctions.LossFunction.MSE) - .updater(new Sgd(0.0)) - .biasUpdater(new Sgd(0.0)) - .activation(Activation.TANH) - .nIn(3) - .nOut(1) - .build() - , - "merge" - ) + .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input") + .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer) + .addLayer(frozenBranchFrozenLayer1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(),frozenBranchUnfrozenLayer0) + .addLayer(frozenBranchFrozenLayer2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(),frozenBranchFrozenLayer1) + .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer) + .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0) + .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1) + .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2) + .addLayer(frozenBranchOutput,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(),"merge") .setOutputs(frozenBranchOutput) .build(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index b78a06093..9570166ed 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest { Map placeholders = new HashMap<>(); placeholders.put("input", f); placeholders.put("label", l); - sd.exec(placeholders, lossMse.getVarName()); - INDArray outSd = a1.getArr(); + Map map = sd.output(placeholders, lossMse.getVarName(), a1.getVarName()); + INDArray outSd = map.get(a1.getVarName()); INDArray outDl4j = net.output(f); assertEquals(testName, outDl4j, outSd); @@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Check score double scoreDl4j = net.score(); - double scoreSd = lossMse.getArr().getDouble(0) + sd.calcRegularizationScore(); + double scoreSd = map.get(lossMse.getVarName()).getDouble(0) + sd.calcRegularizationScore(); assertEquals(testName, scoreDl4j, scoreSd, 1e-6); double lossRegScoreSD = sd.calcRegularizationScore(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index 4787d1082..fc805f0ca 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -145,7 +145,7 @@ public class LocallyConnected1D extends SameDiffLayer { val weightsShape = new long[] {outputSize, featureDim, nOut}; params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape); if (hasBias) { - val biasShape = new long[] {1, nOut}; + val biasShape = new long[] {nOut}; params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape); } } @@ -200,7 +200,7 @@ public class LocallyConnected1D extends SameDiffLayer { if (hasBias) { SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); - SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b); + SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b, true); return activation.asSameDiff("out", sameDiff, biasAddedResult); } else { return activation.asSameDiff("out", sameDiff, result); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index 5426fda9b..ef07c9dc5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -145,7 +145,7 @@ public class LocallyConnected2D extends SameDiffLayer { val weightsShape = new long[] {outputSize[0] * outputSize[1], featureDim, nOut}; params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape); if (hasBias) { - val biasShape = new long[] {1, nOut}; + val biasShape = new long[] {nOut}; params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape); } } @@ -211,7 +211,7 @@ public class LocallyConnected2D extends SameDiffLayer { if (hasBias) { SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); - SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b); + SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, true); return activation.asSameDiff("out", sameDiff, biasAddedResult); } else { return activation.asSameDiff("out", sameDiff, permutedResult); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java index 73cd7db4d..cb58a9813 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java @@ -114,7 +114,7 @@ public class MergeVertex extends BaseGraphVertex { } try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){ - return Nd4j.hstack(in); + return Nd4j.concat(1, in); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java index f1f4b536d..34799d6ad 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java @@ -134,6 +134,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex { Gradient g = new DefaultGradient(); INDArray[] dLdIns; + boolean[] noClose = new boolean[getNumInputArrays()]; try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ if(sameDiff == null){ doInit(); @@ -167,20 +168,21 @@ public class SameDiffGraphVertex extends BaseGraphVertex { //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration //TODO Find a more efficient solution for this + List required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated for (Map.Entry e : paramTable.entrySet()) { INDArray arr = e.getValue(); sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); } - List required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated - for(String s : inputNames){ - required.add(sameDiff.getVariable(s).gradient().getVarName()); - } - sameDiff.execBackwards(phMap, required); + required.addAll(paramTable.keySet()); + required.addAll(inputNames); + + Map gradsMap = sameDiff.calculateGradients(phMap, required); for(String s : paramTable.keySet() ){ - INDArray sdGrad = sameDiff.grad(s).getArr(); + INDArray sdGrad = gradsMap.get(s); INDArray dl4jGrad = gradTable.get(s); dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS + sdGrad.close(); //TODO optimize this g.gradientForVariable().put(s, dl4jGrad); } @@ -195,13 +197,18 @@ public class SameDiffGraphVertex extends BaseGraphVertex { //Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here dLdIns[j] = epsilon; + noClose[j] = true; } } } //TODO optimize for( int i=0; i { sameDiff.clearPlaceholders(true); sameDiff.clearOpInputs(); - return workspaceMgr.dup(ArrayType.ACTIVATIONS, result); + INDArray ret = workspaceMgr.dup(ArrayType.ACTIVATIONS, result); + if(!result.isAttached() && result.closeable()) { + //May be attached in rare edge case - for identity, or if gradients are passed through from output to input + // unchaned, as in identity, add scalar, etc + result.close(); + } + return ret; } } @@ -122,6 +128,7 @@ public class SameDiffLayer extends AbstractLayer { Gradient g = new DefaultGradient(); INDArray dLdIn; + boolean noCloseEps = false; try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ if(sameDiff == null){ doInit(); @@ -151,26 +158,25 @@ public class SameDiffLayer extends AbstractLayer { } List requiredGrads = new ArrayList<>(paramTable.size() + 1); - requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName()); - for(String s : paramTable.keySet()){ - requiredGrads.add(sameDiff.grad(s).getVarName()); - } + requiredGrads.add(INPUT_KEY); + requiredGrads.addAll(paramTable.keySet()); - sameDiff.execBackwards(phMap, requiredGrads); + Map m = sameDiff.calculateGradients(phMap, requiredGrads); for(String s : paramTable.keySet() ){ - INDArray sdGrad = sameDiff.grad(s).getArr(); + INDArray sdGrad = m.get(s); INDArray dl4jGrad = gradTable.get(s); dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS g.gradientForVariable().put(s, dl4jGrad); + sdGrad.close(); } - SDVariable v = sameDiff.grad(INPUT_KEY); - dLdIn = v.getArr(); + dLdIn = m.get(INPUT_KEY); - if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){ + if(dLdIn == null && fn.getGradPlaceholderName().equals(INPUT_KEY)){ //Edge case with lambda layers like identity: SameDiff doesn't store the placeholders // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here dLdIn = epsilon; + noCloseEps = true; } } @@ -178,7 +184,12 @@ public class SameDiffLayer extends AbstractLayer { sameDiff.clearPlaceholders(true); sameDiff.clearOpInputs(); - return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS + Pair ret = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS + if(!noCloseEps && !dLdIn.isAttached() && dLdIn.closeable()) { + //Edge case: identity etc - might just pass gradient array through unchanged + dLdIn.close(); + } + return ret; } /**Returns the parameters of the neural network as a flattened row vector diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs index 0810d2e6e..7fa9722db 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs @@ -106,6 +106,12 @@ public struct FlatNode : IFlatbufferObject #endif public DType[] GetOutputTypesArray() { return __p.__vector_as_array(38); } public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } + public string ControlDeps(int j) { int o = __p.__offset(42); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepsLength { get { int o = __p.__offset(42); return o != 0 ? __p.__vector_len(o) : 0; } } + public string VarControlDeps(int j) { int o = __p.__offset(44); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } } + public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } } public static Offset CreateFlatNode(FlatBufferBuilder builder, int id = 0, @@ -126,9 +132,15 @@ public struct FlatNode : IFlatbufferObject VectorOffset outputNamesOffset = default(VectorOffset), StringOffset opNameOffset = default(StringOffset), VectorOffset outputTypesOffset = default(VectorOffset), - Offset scalarOffset = default(Offset)) { - builder.StartObject(19); + Offset scalarOffset = default(Offset), + VectorOffset controlDepsOffset = default(VectorOffset), + VectorOffset varControlDepsOffset = default(VectorOffset), + VectorOffset controlDepForOffset = default(VectorOffset)) { + builder.StartObject(22); FlatNode.AddOpNum(builder, opNum); + FlatNode.AddControlDepFor(builder, controlDepForOffset); + FlatNode.AddVarControlDeps(builder, varControlDepsOffset); + FlatNode.AddControlDeps(builder, controlDepsOffset); FlatNode.AddScalar(builder, scalarOffset); FlatNode.AddOutputTypes(builder, outputTypesOffset); FlatNode.AddOpName(builder, opNameOffset); @@ -150,7 +162,7 @@ public struct FlatNode : IFlatbufferObject return FlatNode.EndFlatNode(builder); } - public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(19); } + public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); } public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); } @@ -200,6 +212,18 @@ public struct FlatNode : IFlatbufferObject public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void AddScalar(FlatBufferBuilder builder, Offset scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); } + public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(19, controlDepsOffset.Value, 0); } + public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddVarControlDeps(FlatBufferBuilder builder, VectorOffset varControlDepsOffset) { builder.AddOffset(20, varControlDepsOffset.Value, 0); } + public static VectorOffset CreateVarControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateVarControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddControlDepFor(FlatBufferBuilder builder, VectorOffset controlDepForOffset) { builder.AddOffset(21, controlDepForOffset.Value, 0); } + public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } public static Offset EndFlatNode(FlatBufferBuilder builder) { int o = builder.EndObject(); return new Offset(o); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java index f739551f1..8a72cc00a 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java @@ -66,6 +66,12 @@ public final class FlatNode extends Table { public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); } public FlatArray scalar() { return scalar(new FlatArray()); } public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } + public String controlDeps(int j) { int o = __offset(42); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsLength() { int o = __offset(42); return o != 0 ? __vector_len(o) : 0; } + public String varControlDeps(int j) { int o = __offset(44); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; } + public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; } public static int createFlatNode(FlatBufferBuilder builder, int id, @@ -86,9 +92,15 @@ public final class FlatNode extends Table { int outputNamesOffset, int opNameOffset, int outputTypesOffset, - int scalarOffset) { - builder.startObject(19); + int scalarOffset, + int controlDepsOffset, + int varControlDepsOffset, + int controlDepForOffset) { + builder.startObject(22); FlatNode.addOpNum(builder, opNum); + FlatNode.addControlDepFor(builder, controlDepForOffset); + FlatNode.addVarControlDeps(builder, varControlDepsOffset); + FlatNode.addControlDeps(builder, controlDepsOffset); FlatNode.addScalar(builder, scalarOffset); FlatNode.addOutputTypes(builder, outputTypesOffset); FlatNode.addOpName(builder, opNameOffset); @@ -110,7 +122,7 @@ public final class FlatNode extends Table { return FlatNode.endFlatNode(builder); } - public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(19); } + public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); } public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); } @@ -150,6 +162,15 @@ public final class FlatNode extends Table { public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); } public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); } public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); } + public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); } + public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); } + public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); } + public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endFlatNode(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py index 520fe1aad..889eca62f 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py @@ -294,7 +294,52 @@ class FlatNode(object): return obj return None -def FlatNodeStart(builder): builder.StartObject(19) + # FlatNode + def ControlDeps(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatNode + def ControlDepsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # FlatNode + def VarControlDeps(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatNode + def VarControlDepsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # FlatNode + def ControlDepFor(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatNode + def ControlDepForLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + +def FlatNodeStart(builder): builder.StartObject(22) def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0) def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0) @@ -324,4 +369,10 @@ def FlatNodeAddOpName(builder, opName): builder.PrependUOffsetTRelativeSlot(16, def FlatNodeAddOutputTypes(builder, outputTypes): builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(outputTypes), 0) def FlatNodeStartOutputTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1) def FlatNodeAddScalar(builder, scalar): builder.PrependUOffsetTRelativeSlot(18, flatbuffers.number_types.UOffsetTFlags.py_type(scalar), 0) +def FlatNodeAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(19, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0) +def FlatNodeStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTRelativeSlot(20, flatbuffers.number_types.UOffsetTFlags.py_type(varControlDeps), 0) +def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0) +def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4) def FlatNodeEnd(builder): return builder.EndObject() diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs index 9764668a0..325094654 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs @@ -37,6 +37,12 @@ public struct FlatVariable : IFlatbufferObject public FlatArray? Ndarray { get { int o = __p.__offset(12); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } public int Device { get { int o = __p.__offset(14); return o != 0 ? __p.bb.GetInt(o + __p.bb_pos) : (int)0; } } public VarType Variabletype { get { int o = __p.__offset(16); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } } + public string ControlDeps(int j) { int o = __p.__offset(18); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepsLength { get { int o = __p.__offset(18); return o != 0 ? __p.__vector_len(o) : 0; } } + public string ControlDepForOp(int j) { int o = __p.__offset(20); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepForOpLength { get { int o = __p.__offset(20); return o != 0 ? __p.__vector_len(o) : 0; } } + public string ControlDepsForVar(int j) { int o = __p.__offset(22); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepsForVarLength { get { int o = __p.__offset(22); return o != 0 ? __p.__vector_len(o) : 0; } } public static Offset CreateFlatVariable(FlatBufferBuilder builder, Offset idOffset = default(Offset), @@ -45,8 +51,14 @@ public struct FlatVariable : IFlatbufferObject VectorOffset shapeOffset = default(VectorOffset), Offset ndarrayOffset = default(Offset), int device = 0, - VarType variabletype = VarType.VARIABLE) { - builder.StartObject(7); + VarType variabletype = VarType.VARIABLE, + VectorOffset controlDepsOffset = default(VectorOffset), + VectorOffset controlDepForOpOffset = default(VectorOffset), + VectorOffset controlDepsForVarOffset = default(VectorOffset)) { + builder.StartObject(10); + FlatVariable.AddControlDepsForVar(builder, controlDepsForVarOffset); + FlatVariable.AddControlDepForOp(builder, controlDepForOpOffset); + FlatVariable.AddControlDeps(builder, controlDepsOffset); FlatVariable.AddDevice(builder, device); FlatVariable.AddNdarray(builder, ndarrayOffset); FlatVariable.AddShape(builder, shapeOffset); @@ -57,7 +69,7 @@ public struct FlatVariable : IFlatbufferObject return FlatVariable.EndFlatVariable(builder); } - public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); } + public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(10); } public static void AddId(FlatBufferBuilder builder, Offset idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } @@ -68,6 +80,18 @@ public struct FlatVariable : IFlatbufferObject public static void AddNdarray(FlatBufferBuilder builder, Offset ndarrayOffset) { builder.AddOffset(4, ndarrayOffset.Value, 0); } public static void AddDevice(FlatBufferBuilder builder, int device) { builder.AddInt(5, device, 0); } public static void AddVariabletype(FlatBufferBuilder builder, VarType variabletype) { builder.AddSbyte(6, (sbyte)variabletype, 0); } + public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(7, controlDepsOffset.Value, 0); } + public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddControlDepForOp(FlatBufferBuilder builder, VectorOffset controlDepForOpOffset) { builder.AddOffset(8, controlDepForOpOffset.Value, 0); } + public static VectorOffset CreateControlDepForOpVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepForOpVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddControlDepsForVar(FlatBufferBuilder builder, VectorOffset controlDepsForVarOffset) { builder.AddOffset(9, controlDepsForVarOffset.Value, 0); } + public static VectorOffset CreateControlDepsForVarVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepsForVarVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } public static Offset EndFlatVariable(FlatBufferBuilder builder) { int o = builder.EndObject(); return new Offset(o); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.java b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.java index 37e2053c2..d73c990bb 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.java +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.java @@ -28,6 +28,12 @@ public final class FlatVariable extends Table { public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; } public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; } + public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; } + public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; } + public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; } public static int createFlatVariable(FlatBufferBuilder builder, int idOffset, @@ -36,8 +42,14 @@ public final class FlatVariable extends Table { int shapeOffset, int ndarrayOffset, int device, - byte variabletype) { - builder.startObject(7); + byte variabletype, + int controlDepsOffset, + int controlDepForOpOffset, + int controlDepsForVarOffset) { + builder.startObject(10); + FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset); + FlatVariable.addControlDepForOp(builder, controlDepForOpOffset); + FlatVariable.addControlDeps(builder, controlDepsOffset); FlatVariable.addDevice(builder, device); FlatVariable.addNdarray(builder, ndarrayOffset); FlatVariable.addShape(builder, shapeOffset); @@ -48,7 +60,7 @@ public final class FlatVariable extends Table { return FlatVariable.endFlatVariable(builder); } - public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(7); } + public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); } public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); } @@ -58,6 +70,15 @@ public final class FlatVariable extends Table { public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); } public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); } public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); } + public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); } + public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); } + public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); } + public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endFlatVariable(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.py b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.py index e2679c6cd..d0036c247 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.py +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.py @@ -90,7 +90,52 @@ class FlatVariable(object): return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) return 0 -def FlatVariableStart(builder): builder.StartObject(7) + # FlatVariable + def ControlDeps(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatVariable + def ControlDepsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # FlatVariable + def ControlDepForOp(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatVariable + def ControlDepForOpLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # FlatVariable + def ControlDepsForVar(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatVariable + def ControlDepsForVarLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + +def FlatVariableStart(builder): builder.StartObject(10) def FlatVariableAddId(builder, id): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(id), 0) def FlatVariableAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) def FlatVariableAddDtype(builder, dtype): builder.PrependInt8Slot(2, dtype, 0) @@ -99,4 +144,10 @@ def FlatVariableStartShapeVector(builder, numElems): return builder.StartVector( def FlatVariableAddNdarray(builder, ndarray): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(ndarray), 0) def FlatVariableAddDevice(builder, device): builder.PrependInt32Slot(5, device, 0) def FlatVariableAddVariabletype(builder, variabletype): builder.PrependInt8Slot(6, variabletype, 0) +def FlatVariableAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0) +def FlatVariableStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatVariableAddControlDepForOp(builder, controlDepForOp): builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepForOp), 0) +def FlatVariableStartControlDepForOpVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatVariableAddControlDepsForVar(builder, controlDepsForVar): builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepsForVar), 0) +def FlatVariableStartControlDepsForVarVector(builder, numElems): return builder.StartVector(4, numElems, 4) def FlatVariableEnd(builder): return builder.EndObject() diff --git a/libnd4j/include/graph/generated/node_generated.h b/libnd4j/include/graph/generated/node_generated.h index 286547552..6ca85f7b0 100644 --- a/libnd4j/include/graph/generated/node_generated.h +++ b/libnd4j/include/graph/generated/node_generated.h @@ -35,7 +35,10 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_OUTPUTNAMES = 34, VT_OPNAME = 36, VT_OUTPUTTYPES = 38, - VT_SCALAR = 40 + VT_SCALAR = 40, + VT_CONTROLDEPS = 42, + VT_VARCONTROLDEPS = 44, + VT_CONTROLDEPFOR = 46 }; int32_t id() const { return GetField(VT_ID, 0); @@ -94,6 +97,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const FlatArray *scalar() const { return GetPointer(VT_SCALAR); } + const flatbuffers::Vector> *controlDeps() const { + return GetPointer> *>(VT_CONTROLDEPS); + } + const flatbuffers::Vector> *varControlDeps() const { + return GetPointer> *>(VT_VARCONTROLDEPS); + } + const flatbuffers::Vector> *controlDepFor() const { + return GetPointer> *>(VT_CONTROLDEPFOR); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_ID) && @@ -132,6 +144,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(outputTypes()) && VerifyOffset(verifier, VT_SCALAR) && verifier.VerifyTable(scalar()) && + VerifyOffset(verifier, VT_CONTROLDEPS) && + verifier.VerifyVector(controlDeps()) && + verifier.VerifyVectorOfStrings(controlDeps()) && + VerifyOffset(verifier, VT_VARCONTROLDEPS) && + verifier.VerifyVector(varControlDeps()) && + verifier.VerifyVectorOfStrings(varControlDeps()) && + VerifyOffset(verifier, VT_CONTROLDEPFOR) && + verifier.VerifyVector(controlDepFor()) && + verifier.VerifyVectorOfStrings(controlDepFor()) && verifier.EndTable(); } }; @@ -196,6 +217,15 @@ struct FlatNodeBuilder { void add_scalar(flatbuffers::Offset scalar) { fbb_.AddOffset(FlatNode::VT_SCALAR, scalar); } + void add_controlDeps(flatbuffers::Offset>> controlDeps) { + fbb_.AddOffset(FlatNode::VT_CONTROLDEPS, controlDeps); + } + void add_varControlDeps(flatbuffers::Offset>> varControlDeps) { + fbb_.AddOffset(FlatNode::VT_VARCONTROLDEPS, varControlDeps); + } + void add_controlDepFor(flatbuffers::Offset>> controlDepFor) { + fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor); + } explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -228,9 +258,15 @@ inline flatbuffers::Offset CreateFlatNode( flatbuffers::Offset>> outputNames = 0, flatbuffers::Offset opName = 0, flatbuffers::Offset> outputTypes = 0, - flatbuffers::Offset scalar = 0) { + flatbuffers::Offset scalar = 0, + flatbuffers::Offset>> controlDeps = 0, + flatbuffers::Offset>> varControlDeps = 0, + flatbuffers::Offset>> controlDepFor = 0) { FlatNodeBuilder builder_(_fbb); builder_.add_opNum(opNum); + builder_.add_controlDepFor(controlDepFor); + builder_.add_varControlDeps(varControlDeps); + builder_.add_controlDeps(controlDeps); builder_.add_scalar(scalar); builder_.add_outputTypes(outputTypes); builder_.add_opName(opName); @@ -272,7 +308,10 @@ inline flatbuffers::Offset CreateFlatNodeDirect( const std::vector> *outputNames = nullptr, const char *opName = nullptr, const std::vector *outputTypes = nullptr, - flatbuffers::Offset scalar = 0) { + flatbuffers::Offset scalar = 0, + const std::vector> *controlDeps = nullptr, + const std::vector> *varControlDeps = nullptr, + const std::vector> *controlDepFor = nullptr) { return nd4j::graph::CreateFlatNode( _fbb, id, @@ -293,7 +332,10 @@ inline flatbuffers::Offset CreateFlatNodeDirect( outputNames ? _fbb.CreateVector>(*outputNames) : 0, opName ? _fbb.CreateString(opName) : 0, outputTypes ? _fbb.CreateVector(*outputTypes) : 0, - scalar); + scalar, + controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, + varControlDeps ? _fbb.CreateVector>(*varControlDeps) : 0, + controlDepFor ? _fbb.CreateVector>(*controlDepFor) : 0); } inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) { diff --git a/libnd4j/include/graph/generated/node_generated.js b/libnd4j/include/graph/generated/node_generated.js index bd2274dad..dd83c4356 100644 --- a/libnd4j/include/graph/generated/node_generated.js +++ b/libnd4j/include/graph/generated/node_generated.js @@ -344,11 +344,65 @@ nd4j.graph.FlatNode.prototype.scalar = function(obj) { return offset ? (obj || new nd4j.graph.FlatArray).__init(this.bb.__indirect(this.bb_pos + offset), this.bb) : null; }; +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatNode.prototype.controlDeps = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 42); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatNode.prototype.controlDepsLength = function() { + var offset = this.bb.__offset(this.bb_pos, 42); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatNode.prototype.varControlDeps = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 44); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatNode.prototype.varControlDepsLength = function() { + var offset = this.bb.__offset(this.bb_pos, 44); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatNode.prototype.controlDepFor = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 46); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatNode.prototype.controlDepForLength = function() { + var offset = this.bb.__offset(this.bb_pos, 46); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + /** * @param {flatbuffers.Builder} builder */ nd4j.graph.FlatNode.startFlatNode = function(builder) { - builder.startObject(19); + builder.startObject(22); }; /** @@ -713,6 +767,93 @@ nd4j.graph.FlatNode.addScalar = function(builder, scalarOffset) { builder.addFieldOffset(18, scalarOffset, 0); }; +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepsOffset + */ +nd4j.graph.FlatNode.addControlDeps = function(builder, controlDepsOffset) { + builder.addFieldOffset(19, controlDepsOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatNode.createControlDepsVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatNode.startControlDepsVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} varControlDepsOffset + */ +nd4j.graph.FlatNode.addVarControlDeps = function(builder, varControlDepsOffset) { + builder.addFieldOffset(20, varControlDepsOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatNode.createVarControlDepsVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatNode.startVarControlDepsVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepForOffset + */ +nd4j.graph.FlatNode.addControlDepFor = function(builder, controlDepForOffset) { + builder.addFieldOffset(21, controlDepForOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatNode.createControlDepForVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + /** * @param {flatbuffers.Builder} builder * @returns {flatbuffers.Offset} diff --git a/libnd4j/include/graph/generated/variable_generated.h b/libnd4j/include/graph/generated/variable_generated.h index ca1a705a0..465490722 100644 --- a/libnd4j/include/graph/generated/variable_generated.h +++ b/libnd4j/include/graph/generated/variable_generated.h @@ -57,7 +57,10 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_SHAPE = 10, VT_NDARRAY = 12, VT_DEVICE = 14, - VT_VARIABLETYPE = 16 + VT_VARIABLETYPE = 16, + VT_CONTROLDEPS = 18, + VT_CONTROLDEPFOROP = 20, + VT_CONTROLDEPSFORVAR = 22 }; const IntPair *id() const { return GetPointer(VT_ID); @@ -80,6 +83,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VarType variabletype() const { return static_cast(GetField(VT_VARIABLETYPE, 0)); } + const flatbuffers::Vector> *controlDeps() const { + return GetPointer> *>(VT_CONTROLDEPS); + } + const flatbuffers::Vector> *controlDepForOp() const { + return GetPointer> *>(VT_CONTROLDEPFOROP); + } + const flatbuffers::Vector> *controlDepsForVar() const { + return GetPointer> *>(VT_CONTROLDEPSFORVAR); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ID) && @@ -93,6 +105,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyTable(ndarray()) && VerifyField(verifier, VT_DEVICE) && VerifyField(verifier, VT_VARIABLETYPE) && + VerifyOffset(verifier, VT_CONTROLDEPS) && + verifier.VerifyVector(controlDeps()) && + verifier.VerifyVectorOfStrings(controlDeps()) && + VerifyOffset(verifier, VT_CONTROLDEPFOROP) && + verifier.VerifyVector(controlDepForOp()) && + verifier.VerifyVectorOfStrings(controlDepForOp()) && + VerifyOffset(verifier, VT_CONTROLDEPSFORVAR) && + verifier.VerifyVector(controlDepsForVar()) && + verifier.VerifyVectorOfStrings(controlDepsForVar()) && verifier.EndTable(); } }; @@ -121,6 +142,15 @@ struct FlatVariableBuilder { void add_variabletype(VarType variabletype) { fbb_.AddElement(FlatVariable::VT_VARIABLETYPE, static_cast(variabletype), 0); } + void add_controlDeps(flatbuffers::Offset>> controlDeps) { + fbb_.AddOffset(FlatVariable::VT_CONTROLDEPS, controlDeps); + } + void add_controlDepForOp(flatbuffers::Offset>> controlDepForOp) { + fbb_.AddOffset(FlatVariable::VT_CONTROLDEPFOROP, controlDepForOp); + } + void add_controlDepsForVar(flatbuffers::Offset>> controlDepsForVar) { + fbb_.AddOffset(FlatVariable::VT_CONTROLDEPSFORVAR, controlDepsForVar); + } explicit FlatVariableBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -141,8 +171,14 @@ inline flatbuffers::Offset CreateFlatVariable( flatbuffers::Offset> shape = 0, flatbuffers::Offset ndarray = 0, int32_t device = 0, - VarType variabletype = VarType_VARIABLE) { + VarType variabletype = VarType_VARIABLE, + flatbuffers::Offset>> controlDeps = 0, + flatbuffers::Offset>> controlDepForOp = 0, + flatbuffers::Offset>> controlDepsForVar = 0) { FlatVariableBuilder builder_(_fbb); + builder_.add_controlDepsForVar(controlDepsForVar); + builder_.add_controlDepForOp(controlDepForOp); + builder_.add_controlDeps(controlDeps); builder_.add_device(device); builder_.add_ndarray(ndarray); builder_.add_shape(shape); @@ -161,7 +197,10 @@ inline flatbuffers::Offset CreateFlatVariableDirect( const std::vector *shape = nullptr, flatbuffers::Offset ndarray = 0, int32_t device = 0, - VarType variabletype = VarType_VARIABLE) { + VarType variabletype = VarType_VARIABLE, + const std::vector> *controlDeps = nullptr, + const std::vector> *controlDepForOp = nullptr, + const std::vector> *controlDepsForVar = nullptr) { return nd4j::graph::CreateFlatVariable( _fbb, id, @@ -170,7 +209,10 @@ inline flatbuffers::Offset CreateFlatVariableDirect( shape ? _fbb.CreateVector(*shape) : 0, ndarray, device, - variabletype); + variabletype, + controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, + controlDepForOp ? _fbb.CreateVector>(*controlDepForOp) : 0, + controlDepsForVar ? _fbb.CreateVector>(*controlDepsForVar) : 0); } inline const nd4j::graph::FlatVariable *GetFlatVariable(const void *buf) { diff --git a/libnd4j/include/graph/generated/variable_generated.js b/libnd4j/include/graph/generated/variable_generated.js index 9012af2de..4bcdcd741 100644 --- a/libnd4j/include/graph/generated/variable_generated.js +++ b/libnd4j/include/graph/generated/variable_generated.js @@ -125,11 +125,65 @@ nd4j.graph.FlatVariable.prototype.variabletype = function() { return offset ? /** @type {nd4j.graph.VarType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.VarType.VARIABLE; }; +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatVariable.prototype.controlDeps = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 18); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatVariable.prototype.controlDepsLength = function() { + var offset = this.bb.__offset(this.bb_pos, 18); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatVariable.prototype.controlDepForOp = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 20); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatVariable.prototype.controlDepForOpLength = function() { + var offset = this.bb.__offset(this.bb_pos, 20); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatVariable.prototype.controlDepsForVar = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 22); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatVariable.prototype.controlDepsForVarLength = function() { + var offset = this.bb.__offset(this.bb_pos, 22); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + /** * @param {flatbuffers.Builder} builder */ nd4j.graph.FlatVariable.startFlatVariable = function(builder) { - builder.startObject(7); + builder.startObject(10); }; /** @@ -209,6 +263,93 @@ nd4j.graph.FlatVariable.addVariabletype = function(builder, variabletype) { builder.addFieldInt8(6, variabletype, nd4j.graph.VarType.VARIABLE); }; +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepsOffset + */ +nd4j.graph.FlatVariable.addControlDeps = function(builder, controlDepsOffset) { + builder.addFieldOffset(7, controlDepsOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatVariable.createControlDepsVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatVariable.startControlDepsVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepForOpOffset + */ +nd4j.graph.FlatVariable.addControlDepForOp = function(builder, controlDepForOpOffset) { + builder.addFieldOffset(8, controlDepForOpOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatVariable.createControlDepForOpVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatVariable.startControlDepForOpVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepsForVarOffset + */ +nd4j.graph.FlatVariable.addControlDepsForVar = function(builder, controlDepsForVarOffset) { + builder.addFieldOffset(9, controlDepsForVarOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatVariable.createControlDepsForVarVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatVariable.startControlDepsForVarVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + /** * @param {flatbuffers.Builder} builder * @returns {flatbuffers.Offset} diff --git a/libnd4j/include/graph/scheme/node.fbs b/libnd4j/include/graph/scheme/node.fbs index 930702f6d..92975e216 100644 --- a/libnd4j/include/graph/scheme/node.fbs +++ b/libnd4j/include/graph/scheme/node.fbs @@ -52,6 +52,12 @@ table FlatNode { //Scalar value - used for scalar ops. Should be single value only. scalar:FlatArray; + + //Control dependencies + controlDeps:[string]; + varControlDeps:[string]; + controlDepFor:[string]; + } root_type FlatNode; \ No newline at end of file diff --git a/libnd4j/include/graph/scheme/variable.fbs b/libnd4j/include/graph/scheme/variable.fbs index 31eafafa7..1e8010d43 100644 --- a/libnd4j/include/graph/scheme/variable.fbs +++ b/libnd4j/include/graph/scheme/variable.fbs @@ -37,6 +37,10 @@ table FlatVariable { device:int; // default is -1, which means _auto_ variabletype:VarType; + + controlDeps:[string]; + controlDepForOp:[string]; + controlDepsForVar:[string]; } root_type FlatVariable; \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index eb3424007..5ce25628e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -659,7 +659,8 @@ public abstract class DifferentialFunction { if(sameDiff == null) this.ownName = UUID.randomUUID().toString(); else { - this.ownName = sameDiff.getOpName(opName()); + String n = sameDiff.getOpName(opName()); + this.ownName = n; } if(sameDiff != null) @@ -696,30 +697,11 @@ public abstract class DifferentialFunction { } @JsonIgnore - private INDArray getX() { - INDArray ret = sameDiff.getArrForVarName(args()[0].getVarName()); - return ret; + public INDArray getInputArgument(int index){ + //Subclasses should implement this + throw new UnsupportedOperationException("Not implemented"); } - @JsonIgnore - private INDArray getY() { - if(args().length > 1) { - INDArray ret = sameDiff.getArrForVarName(args()[1].getVarName()); - return ret; - } - return null; - } - - @JsonIgnore - private INDArray getZ() { - if(isInPlace()) - return getX(); - SDVariable opId = outputVariables()[0]; - INDArray ret = opId.getArr(); - return ret; - } - - /** @@ -860,4 +842,8 @@ public abstract class DifferentialFunction { public int getNumOutputs(){return -1;} + /** + * Clear the input and output INDArrays, if any are set + */ + public abstract void clearArrays(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 4b042dded..318ec4478 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -982,8 +982,8 @@ public class DifferentialFunctionFactory { return new CumProdBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable(); } - public SDVariable biasAdd(SDVariable input, SDVariable bias) { - return new BiasAdd(sameDiff(), input, bias).outputVariable(); + public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { + return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable(); } public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java index f0dcecb49..92d1ce120 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java @@ -24,6 +24,7 @@ import lombok.Getter; import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IMetric; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -319,6 +320,7 @@ public class History { * Gets the training evaluations ran during the last epoch */ public EvaluationRecord finalTrainingEvaluations(){ + Preconditions.checkState(!trainingHistory.isEmpty(), "Cannot get final training evaluation - history is empty"); return trainingHistory.get(trainingHistory.size() - 1); } @@ -326,6 +328,7 @@ public class History { * Gets the validation evaluations ran during the last epoch */ public EvaluationRecord finalValidationEvaluations(){ + Preconditions.checkState(!validationHistory.isEmpty(), "Cannot get final validation evaluation - history is empty"); return validationHistory.get(validationHistory.size() - 1); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index a97668e8e..f2818e7e4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -16,34 +16,23 @@ package org.nd4j.autodiff.samediff; -import java.util.Objects; import lombok.*; import lombok.extern.slf4j.Slf4j; -import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.blas.params.MMulTranspose; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.weightinit.WeightInitScheme; -import org.nd4j.weightinit.impl.ZeroInitScheme; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.Map; +import java.util.Objects; /** * @@ -167,6 +156,10 @@ public class SDVariable implements Serializable { if(sameDiff.arrayAlreadyExistsForVarName(getVarName())) return sameDiff.getArrForVarName(getVarName()); + if(variableType == VariableType.ARRAY){ + throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead"); + } + //initialize value if it's actually a scalar constant (zero or 1 typically...) if(variableType == VariableType.VARIABLE && weightInitScheme != null && shape != null){ INDArray arr = weightInitScheme.create(dataType, shape); @@ -211,8 +204,8 @@ public class SDVariable implements Serializable { * created automatically when training is performed. */ public SDVariable getGradient() { - Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s variable \"%s\": only floating" + - " point variables have gradients", getVarName(), dataType()); + Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s datatype variable \"%s\": only floating" + + " point variables have gradients", dataType(), getVarName()); return sameDiff.getGradForVariable(getVarName()); } @@ -230,7 +223,7 @@ public class SDVariable implements Serializable { } long[] initialShape = sameDiff.getShapeForVarName(getVarName()); - if(initialShape == null) { + if(initialShape == null && variableType != VariableType.ARRAY) { val arr = getArr(); if(arr != null) return arr.shape(); @@ -254,7 +247,7 @@ public class SDVariable implements Serializable { public DataType dataType() { if(this.dataType == null){ //Try to infer datatype instead of returning null - if(getArr() != null){ + if(variableType != VariableType.ARRAY && getArr() != null){ this.dataType = getArr().dataType(); } } @@ -1518,26 +1511,59 @@ public class SDVariable implements Serializable { /** * Add a control dependency for this variable on the specified variable.
- * Control depnedencies can be used to enforce the execution order. + * Control dependencies can be used to enforce the execution order. * For example, if a control dependency X->Y exists, then Y will only be executed after X is executed - even * if Y wouldn't normally depend on the result/values of X. * * @param controlDependency Control dependency to add for this variable */ public void addControlDependency(SDVariable controlDependency){ - String cdN = controlDependency.getVarName(); - String n = this.getVarName(); - Variable v = sameDiff.getVariables().get(n); - if(v.getControlDeps() == null) - v.setControlDeps(new ArrayList()); - if(!v.getControlDeps().contains(cdN)) - v.getControlDeps().add(cdN); + Variable vThis = sameDiff.getVariables().get(getVarName()); + Variable vCD = sameDiff.getVariables().get(controlDependency.getVarName()); - Variable v2 = sameDiff.getVariables().get(cdN); - if(v2.getControlDepsForVar() == null) - v2.setControlDepsForVar(new ArrayList()); - if(!v2.getControlDepsForVar().contains(n)) - v2.getControlDepsForVar().add(n); + //If possible: add control dependency on ops + if(vThis.getOutputOfOp() != null && vCD.getOutputOfOp() != null ){ + //Op -> Op case + SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp()); + SameDiffOp oCD = sameDiff.getOps().get(vCD.getOutputOfOp()); + + if(oThis.getControlDeps() == null) + oThis.setControlDeps(new ArrayList()); + if(!oThis.getControlDeps().contains(oCD.getName())) + oThis.getControlDeps().add(oCD.getName()); + + if(oCD.getControlDepFor() == null) + oCD.setControlDepFor(new ArrayList()); + if(!oCD.getControlDepFor().contains(oThis.getName())) + oCD.getControlDepFor().add(oThis.getName()); + } else { + if(vThis.getOutputOfOp() != null){ + //const/ph -> op case + SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp()); + + if(oThis.getVarControlDeps() == null) + oThis.setVarControlDeps(new ArrayList()); + + if(!oThis.getVarControlDeps().contains(vCD.getName())) + oThis.getVarControlDeps().add(vCD.getName()); + + if(vCD.getControlDepsForOp() == null) + vCD.setControlDepsForOp(new ArrayList()); + if(!vCD.getControlDepsForOp().contains(oThis.getName())) + vCD.getControlDepsForOp().add(oThis.getName()); + } else { + //const/ph -> const/ph case + if(vThis.getControlDeps() == null) + vThis.setControlDeps(new ArrayList()); + if(!vThis.getControlDeps().contains(vCD.getName())) + vThis.getControlDeps().add(vCD.getName()); + + if(vCD.getControlDepsForVar() == null) + vCD.setControlDepsForVar(new ArrayList()); + if(!vCD.getControlDepsForVar().contains(vThis.getName())) + vCD.getControlDepsForVar().add(vThis.getName()); + } + } } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index ddd9ecbb2..1bcb3aedb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -16,58 +16,16 @@ package org.nd4j.autodiff.samediff; -import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; - import com.google.flatbuffers.FlatBufferBuilder; -import java.io.BufferedInputStream; -import java.io.BufferedOutputStream; -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.lang.reflect.Method; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.IdentityHashMap; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Queue; -import java.util.Set; -import java.util.Stack; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; +import lombok.*; import lombok.extern.slf4j.Slf4j; -import lombok.val; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunctionFactory; -import org.nd4j.autodiff.listeners.At; -import org.nd4j.autodiff.listeners.Listener; -import org.nd4j.autodiff.listeners.ListenerResponse; -import org.nd4j.autodiff.listeners.Loss; -import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.listeners.*; import org.nd4j.autodiff.listeners.impl.HistoryListener; import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.listeners.records.LossCurve; @@ -75,34 +33,14 @@ import org.nd4j.autodiff.samediff.config.BatchOutputConfig; import org.nd4j.autodiff.samediff.config.EvaluationConfig; import org.nd4j.autodiff.samediff.config.FitConfig; import org.nd4j.autodiff.samediff.config.OutputConfig; -import org.nd4j.autodiff.samediff.internal.AbstractSession; -import org.nd4j.autodiff.samediff.internal.DataTypesSession; -import org.nd4j.autodiff.samediff.internal.InferenceSession; -import org.nd4j.autodiff.samediff.internal.SameDiffOp; -import org.nd4j.autodiff.samediff.internal.Variable; -import org.nd4j.autodiff.samediff.ops.SDBaseOps; -import org.nd4j.autodiff.samediff.ops.SDBitwise; -import org.nd4j.autodiff.samediff.ops.SDCNN; -import org.nd4j.autodiff.samediff.ops.SDImage; -import org.nd4j.autodiff.samediff.ops.SDLoss; -import org.nd4j.autodiff.samediff.ops.SDMath; -import org.nd4j.autodiff.samediff.ops.SDNN; -import org.nd4j.autodiff.samediff.ops.SDRNN; -import org.nd4j.autodiff.samediff.ops.SDRandom; +import org.nd4j.autodiff.samediff.internal.*; +import org.nd4j.autodiff.samediff.ops.*; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROC; -import org.nd4j.graph.ExecutionMode; -import org.nd4j.graph.FlatArray; -import org.nd4j.graph.FlatConfiguration; -import org.nd4j.graph.FlatGraph; -import org.nd4j.graph.FlatNode; -import org.nd4j.graph.FlatVariable; -import org.nd4j.graph.IntPair; -import org.nd4j.graph.OpType; -import org.nd4j.graph.UpdaterState; +import org.nd4j.graph.*; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -112,8 +50,6 @@ import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.api.ops.impl.controlflow.If; -import org.nd4j.linalg.api.ops.impl.controlflow.While; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; @@ -136,7 +72,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.GradientUpdater; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.primitives.AtomicBoolean; -import org.nd4j.linalg.primitives.AtomicDouble; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.DeviceLocalNDArray; @@ -152,6 +87,17 @@ import org.nd4j.weightinit.impl.NDArraySupplierInitScheme; import org.nd4j.weightinit.impl.ZeroInitScheme; import org.tensorflow.framework.GraphDef; +import java.io.*; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; + /** * SameDiff is the entrypoint for ND4J's automatic differentiation functionality. *

@@ -683,7 +629,7 @@ public class SameDiff extends SDBaseOps { for (val var : variables()) { SDVariable clone = var.clone(this); SDVariable newVar = sameDiff.var(clone); - if (var.getArr() != null && var.getVariableType() != VariableType.ARRAY) { //ARRAY type = "activations" - are overwritten anyway + if (var.getVariableType() != VariableType.ARRAY && var.getArr() != null ) { //ARRAY type = "activations" - are overwritten anyway sameDiff.associateArrayWithVariable(var.getArr(), newVar); } @@ -795,9 +741,9 @@ public class SameDiff extends SDBaseOps { * @param function the function to get the inputs for * @return the input ids for a given function */ - public String[] getInputsForOp(DifferentialFunction function) { + public String[] getInputsForOp(@NonNull DifferentialFunction function) { if (!ops.containsKey(function.getOwnName())) - throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName()); + throw new ND4JIllegalStateException("Unknown function instance id found: \"" + function.getOwnName() + "\""); List inputs = ops.get(function.getOwnName()).getInputsToOp(); return inputs == null ? null : inputs.toArray(new String[inputs.size()]); } @@ -1102,12 +1048,8 @@ public class SameDiff extends SDBaseOps { constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true)); break; case ARRAY: - // FIXME: remove this before release - val session = sessions.get(Thread.currentThread().getId()); - val varId = session.newVarId(variable.getVarName(), AbstractSession.OUTER_FRAME, 0, null); - session.getNodeOutputs().put(varId, arr); - //throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY"); - break; + throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" + + " this type of variable is calculated "); case PLACEHOLDER: //Validate placeholder shapes: long[] phShape = variable.placeholderShape(); @@ -2152,11 +2094,32 @@ public class SameDiff extends SDBaseOps { requiredVars.addAll(l.requiredVariables(this).trainingVariables()); } - ArrayList listenersWitHistory = new ArrayList<>(listeners); + List listenersWitHistory = new ArrayList<>(listeners); + for(Listener l : this.listeners){ + if(!listenersWitHistory.contains(l)) + listenersWitHistory.add(l); + } listenersWitHistory.add(history); - for (int i = 0; i < numEpochs; i++) { + SameDiff gradInstance = getFunction("grad"); + if(gradInstance == null){ + createGradFunction(); + gradInstance = getFunction("grad"); + } + TrainingSession ts = new TrainingSession(gradInstance); + gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it + + Set paramsToTrain = new LinkedHashSet<>(); + for(Variable v : variables.values()){ + if(v.getVariable().getVariableType() == VariableType.VARIABLE){ + //TODO not all variable type are needed - i.e., variable that doesn't impact loss should be skipped + paramsToTrain.add(v.getName()); + } + } + + Loss lastLoss = null; + for (int i = 0; i < numEpochs; i++) { if (incrementEpochCount && hasListeners) { at.setEpoch(trainingConfig.getEpochCount()); for (Listener l : activeListeners) { @@ -2200,153 +2163,38 @@ public class SameDiff extends SDBaseOps { Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training"); resolveVariablesWith(placeholders); - //Calculate gradients: - execBackwards(placeholders, at.operation(), ds, requiredVars, activeListeners); - - - //Apply updater: + //Call TrainingSession to perform training if (!initializedTraining) initializeTraining(); - Map, AtomicDouble> regScore = null; //Holds regularization scores for later reporting to listeners - if (hasListeners) { - regScore = new HashMap<>(); - } + lastLoss = ts.trainingIteration( + trainingConfig, + placeholders, + paramsToTrain, + updaterMap, + ds, + getLossVariables(), + listenersWitHistory, + at); - int iteration = trainingConfig.getIterationCount(); - int e = trainingConfig.getEpochCount(); - for (Variable v : variables.values()) { - //Only update trainable params - float type parameters (variable type vars) - SDVariable sdv = v.getVariable(); - if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) - continue; - - - INDArray param = sdv.getArr(); - SDVariable gradVar = sdv.getGradient(); - if (gradVar == null) { - //Not all trainable parameters have gradients defined. - //Consider graph: in1->loss1; in2->loss2, where we optimize only loss1. - //No gradient will be present for in2, because in2 doesn't impact loss1 at all - continue; - } - INDArray grad = gradVar.getArr(); - //Note: don't need to divide by minibatch - that should be handled in loss function and hence loss function gradients, - // which should flow through to here - - //Pre-apply regularization (L1, L2) - List r = trainingConfig.getRegularization(); - int iterCount = trainingConfig.getIterationCount(); - int epochCount = trainingConfig.getEpochCount(); - double lr = trainingConfig.getUpdater().hasLearningRate() ? trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0; - if (r != null && r.size() > 0) { - for (Regularization reg : r) { - if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) { - reg.apply(param, grad, lr, iterCount, epochCount); - } - } - } - - //Apply updater. Note that we need to reshape to [1,length] for updater - INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f'); //TODO make sure we always reshape in same order! - Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", sdv); - GradientUpdater u = updaterMap.get(sdv.getVarName()); - try { - u.applyUpdater(reshapedView, iteration, e); - } catch (Throwable t) { - throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + sdv.getVarName() - + "\": either parameter size is inconsistent between iterations, or \"" + sdv.getVarName() + "\" should not be a trainable parameter?", t); - } - - //Post-apply regularization (weight decay) - if (r != null && r.size() > 0) { - for (Regularization reg : r) { - if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) { - reg.apply(param, grad, lr, iterCount, epochCount); - if (hasListeners) { - double score = reg.score(param, iterCount, epochCount); - if (!regScore.containsKey(reg.getClass())) { - regScore.put(reg.getClass(), new AtomicDouble()); - } - regScore.get(reg.getClass()).addAndGet(score); - } - } - } - } - - if (hasListeners) { - for (Listener l : activeListeners) { - if (l.isActive(at.operation())) - l.preUpdate(this, at, v, reshapedView); - } - } - - - if (trainingConfig.isMinimize()) { - param.subi(grad); - } else { - param.addi(grad); - } - } - - double[] d = new double[lossVariables.size() + regScore.size()]; - List lossVars; - if (regScore.size() > 0) { - lossVars = new ArrayList<>(lossVariables.size() + regScore.size()); - lossVars.addAll(lossVariables); - int s = regScore.size(); - //Collect regularization losses - for (Map.Entry, AtomicDouble> entry : regScore.entrySet()) { - lossVars.add(entry.getKey().getSimpleName()); - d[s] = entry.getValue().get(); - } - } else { - lossVars = lossVariables; - } - - //Collect the losses... - SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY); - int count = 0; - for (String s : lossVariables) { - INDArray arr = gradFn.getArrForVarName(s); - double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue(); - d[count++] = l; - } - - Loss loss = new Loss(lossVars, d); - - if (lossNames == null) { - lossNames = lossVars; - } else { - Preconditions.checkState(lossNames.equals(lossVars), - "Loss names mismatch, expected: %s, got: %s", lossNames, lossVars); - } if (lossSums == null) { - lossSums = d; + lossSums = lastLoss.getLosses().clone(); } else { - Preconditions.checkState(lossNames.equals(lossVars), - "Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length); - for (int j = 0; j < lossSums.length; j++) { - lossSums[j] += d[j]; + lossSums[j] += lastLoss.getLosses()[j]; } } lossCount++; - if (hasListeners) { - for (Listener l : activeListeners) { - l.iterationDone(this, at, ds, loss); - } - - } - trainingConfig.incrementIterationCount(); } long epochTime = System.currentTimeMillis() - epochStartTime; if (incrementEpochCount) { + lossNames = lastLoss.getLossNames(); + for (int j = 0; j < lossSums.length; j++) lossSums[j] /= lossCount; @@ -2356,14 +2204,13 @@ public class SameDiff extends SDBaseOps { lossCurve = new LossCurve(lossSums, lossNames); } + if (incrementEpochCount) { if (hasListeners) { - boolean doStop = false; Listener stopped = null; for (Listener l : activeListeners) { - ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime); if (res == ListenerResponse.STOP && (i < numEpochs - 1)) { @@ -2431,7 +2278,6 @@ public class SameDiff extends SDBaseOps { trainingConfig.incrementEpochCount(); } - if (i < numEpochs - 1) { iter.reset(); } @@ -2507,7 +2353,9 @@ public class SameDiff extends SDBaseOps { INDArray arr = v.getVariable().getArr(); long stateSize = trainingConfig.getUpdater().stateSize(arr.length()); INDArray view = stateSize == 0 ? null : Nd4j.createUninitialized(arr.dataType(), 1, stateSize); - updaterMap.put(v.getName(), trainingConfig.getUpdater().instantiate(view, true)); + GradientUpdater gu = trainingConfig.getUpdater().instantiate(view, false); + gu.setStateViewArray(view, arr.shape(), arr.ordering(), true); + updaterMap.put(v.getName(), gu); } initializedTraining = true; @@ -3862,7 +3710,8 @@ public class SameDiff extends SDBaseOps { long thisSize = trainingConfig.getUpdater().stateSize(arr.length()); if (thisSize > 0) { INDArray stateArr = Nd4j.create(arr.dataType(), 1, thisSize); - GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, true); + GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, false); + u.setStateViewArray(stateArr, arr.shape(), arr.ordering(), true); //TODO eventually this should be 1 call... updaterMap.put(v.getVarName(), u); } else { GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray) null, true); @@ -3946,7 +3795,53 @@ public class SameDiff extends SDBaseOps { sessions.clear(); //Recalculate datatypes of outputs, and dynamically update them - calculateOutputDataTypes(true); + Set allSeenOps = new HashSet<>(); + Queue queueOps = new LinkedList<>(); + + for(String s : dataTypeMap.keySet()){ + Variable v = variables.get(s); + v.getVariable().setDataType(dataTypeMap.get(s)); + List inToOp = v.getInputsForOp(); + if(inToOp != null){ + for(String op : inToOp) { + if (!allSeenOps.contains(op)) { + allSeenOps.add(op); + queueOps.add(op); + } + } + } + } + + while(!queueOps.isEmpty()){ + String op = queueOps.remove(); + SameDiffOp o = ops.get(op); + List inVars = o.getInputsToOp(); + List inDTypes = new ArrayList<>(); + if(inVars != null) { + for (String s : inVars) { + SDVariable v = variables.get(s).getVariable(); + inDTypes.add(v.dataType()); + } + } + List outDtypes = o.getOp().calculateOutputDataTypes(inDTypes); + List outVars = o.getOutputsOfOp(); + for( int i=0; i 0 && function.args()[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye - ordering = function.args()[0].getArr().ordering(); - } if (checkGet == null) { //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); @@ -4530,45 +4423,6 @@ public class SameDiff extends SDBaseOps { return sameDiffFunctionInstances.get(functionName); } - - /** - * @deprecated Use {@link SDBaseOps#whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} - */ - @Deprecated - public While whileStatement(SameDiffConditional sameDiffConditional, - SameDiffFunctionDefinition conditionBody, - SameDiffFunctionDefinition loopBody - , SDVariable[] inputVars) { - return While.builder() - .inputVars(inputVars) - .condition(conditionBody) - .predicate(sameDiffConditional) - .trueBody(loopBody) - .parent(this) - .blockName("while-" + UUID.randomUUID().toString()) - .build(); - } - - /** - * @deprecated Use {@link SDBaseOps#ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} - */ - @Deprecated - public If ifStatement(SameDiffConditional conditional, - SameDiffFunctionDefinition conditionBody, - SameDiffFunctionDefinition trueBody, - SameDiffFunctionDefinition falseBody - , SDVariable[] inputVars) { - return If.builder() - .conditionBody(conditionBody) - .falseBody(falseBody) - .trueBody(trueBody) - .predicate(conditional) - .inputVars(inputVars) - .parent(this) - .blockName("if-" + UUID.randomUUID().toString()) - .build(); - } - /** * Create a new TensorArray. */ @@ -4648,6 +4502,51 @@ public class SameDiff extends SDBaseOps { return execSingle(placeholders, outputs.get(0)); } + /** + * See {@link #calculateGradients(Map, Collection)} + */ + public Map calculateGradients(Map placeholderVals, @NonNull String... variables) { + Preconditions.checkArgument(variables.length > 0, "No variables were specified"); + return calculateGradients(placeholderVals, Arrays.asList(variables)); + } + + /** + * Calculate and return the gradients for the specified variables + * + * @param placeholderVals Placeholders. May be null + * @param variables Names of the variables that you want the gradient arrays for + * @return Gradients as a map, keyed by the variable name + */ + public Map calculateGradients(Map placeholderVals, @NonNull Collection variables) { + Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified"); + if (getFunction(GRAD_FN_KEY) == null) { + createGradFunction(); + } + + List gradVarNames = new ArrayList<>(variables.size()); + for (String s : variables) { + Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s); + SDVariable v = getVariable(s).getGradient(); + if (v != null) { + //In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable + gradVarNames.add(v.getVarName()); + } + } + + //Key is gradient variable name + Map grads = getFunction(GRAD_FN_KEY).output(placeholderVals, gradVarNames); + + Map out = new HashMap<>(); + for (String s : variables) { + if (getVariable(s).getGradient() != null) { + String gradVar = getVariable(s).getGradient().getVarName(); + out.put(s, grads.get(gradVar)); + } + } + + return out; + } + /** * Create (if required) and then calculate the variable gradients (backward pass) for this graph.
* After execution, the gradient arrays can be accessed using {@code myVariable.getGradient().getArr()}
@@ -4660,6 +4559,7 @@ public class SameDiff extends SDBaseOps { * * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map */ + @Deprecated public void execBackwards(Map placeholders, Operation op) { execBackwards(placeholders, op, null, Collections.emptyList(), Collections.emptyList()); } @@ -4669,10 +4569,12 @@ public class SameDiff extends SDBaseOps { *

* Uses {@link Operation#INFERENCE}. */ + @Deprecated public void execBackwards(Map placeholders) { execBackwards(placeholders, Operation.INFERENCE); } + @Deprecated protected void execBackwards(Map placeholders, Operation op, MultiDataSet batch, Collection requiredActivations, List activeListeners) { if (getFunction(GRAD_FN_KEY) == null) { createGradFunction(); @@ -4709,6 +4611,7 @@ public class SameDiff extends SDBaseOps { /** * See {@link #execBackwards(Map, List, Operation)} */ + @Deprecated public Map execBackwards(Map placeholders, Operation op, String... variableGradNamesList) { return execBackwards(placeholders, Arrays.asList(variableGradNamesList), op, null, Collections.emptyList(), Collections.emptyList()); } @@ -4718,6 +4621,7 @@ public class SameDiff extends SDBaseOps { *

* Uses {@link Operation#INFERENCE}. */ + @Deprecated public Map execBackwards(Map placeholders, String... variableGradNamesList) { return execBackwards(placeholders, Operation.INFERENCE, variableGradNamesList); } @@ -4730,6 +4634,7 @@ public class SameDiff extends SDBaseOps { * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map * @param variableGradNamesList Names of the gradient variables to calculate */ + @Deprecated public Map execBackwards(Map placeholders, List variableGradNamesList, Operation operation) { return execBackwards(placeholders, variableGradNamesList, operation, null, Collections.emptyList(), Collections.emptyList()); } @@ -4739,10 +4644,12 @@ public class SameDiff extends SDBaseOps { *

* Uses {@link Operation#INFERENCE}. */ + @Deprecated public Map execBackwards(Map placeholders, List variableGradNamesList) { return execBackwards(placeholders, variableGradNamesList, Operation.INFERENCE); } + @Deprecated protected Map execBackwards(Map placeholders, List variableGradNamesList, Operation operation, MultiDataSet batch, Collection requiredActivations, List activeListeners) { if (getFunction(GRAD_FN_KEY) == null) { @@ -5462,7 +5369,7 @@ public class SameDiff extends SDBaseOps { 0, 0, -1, - 0, 0, 0, 0, 0, 0); + 0, 0, 0, 0, 0, 0, 0, 0, 0); return flatNode; } @@ -5538,7 +5445,7 @@ public class SameDiff extends SDBaseOps { val idxForOps = new IdentityHashMap(); List allVars = variables(); for (SDVariable variable : allVars) { - INDArray arr = variable.getArr(); + INDArray arr = variable.getVariableType() == VariableType.ARRAY ? null : variable.getArr(); log.trace("Exporting variable: [{}]", variable.getVarName()); //If variable is the output of some op - let's use the ONE index for exporting, and properly track the output @@ -5582,7 +5489,26 @@ public class SameDiff extends SDBaseOps { shape = FlatVariable.createShapeVector(bufferBuilder, shp); } - int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType); + int controlDeps = 0; + int controlDepsForOp = 0; + int controlDepsForVar = 0; + Variable v = variables.get(varName); + + int[] cds = FlatBuffersMapper.mapOrNull(v.getControlDeps(), bufferBuilder); + if(cds != null) + controlDeps = FlatVariable.createControlDepsVector(bufferBuilder, cds); + + int[] cdsForOp = FlatBuffersMapper.mapOrNull(v.getControlDepsForOp(), bufferBuilder); + if(cdsForOp != null) + controlDepsForOp = FlatVariable.createControlDepForOpVector(bufferBuilder, cdsForOp); + + int[] cdsForVar = FlatBuffersMapper.mapOrNull(v.getControlDepsForVar(), bufferBuilder); + if(cdsForVar != null) + controlDepsForVar = FlatVariable.createControlDepsForVarVector(bufferBuilder, cdsForVar); + + + int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, + array, -1, varType, controlDeps, controlDepsForOp, controlDepsForVar); flatVariables.add(flatVariable); } @@ -5593,43 +5519,6 @@ public class SameDiff extends SDBaseOps { flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId)); } - // we're dumping scopes now - for (Map.Entry scope : sameDiffFunctionInstances.entrySet()) { - if (scope.getKey().equalsIgnoreCase(GRAD_FN_KEY)) { - //Skip the gradient function for export - continue; - } - - flatNodes.add(asFlatNode(scope.getKey(), scope.getValue(), bufferBuilder)); - val currVarList = new ArrayList(scope.getValue().variables()); - // converting all ops from node - for (val node : scope.getValue().variables()) { - INDArray arr = node.getArr(); - if (arr == null) { - continue; - } - - int name = bufferBuilder.createString(node.getVarName()); - int array = arr.toFlatArray(bufferBuilder); - int id = IntPair.createIntPair(bufferBuilder, ++idx, 0); - - val pair = parseVariable(node.getVarName()); - reverseMap.put(pair.getFirst(), idx); - - log.trace("Adding [{}] as [{}]", pair.getFirst(), idx); - - byte varType = (byte) node.getVariableType().ordinal(); - int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()), 0, array, -1, varType); - flatVariables.add(flatVariable); - } - - //add functions - for (SameDiffOp op : scope.getValue().ops.values()) { - DifferentialFunction func = op.getOp(); - flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, currVarList, reverseMap, forwardMap, framesMap, idCounter, null)); - } - } - int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets)); int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables)); int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes)); @@ -5958,7 +5847,7 @@ public class SameDiff extends SDBaseOps { vars.add(fg.variables(i)); } - FlatConfiguration conf = fg.configuration(); +// FlatConfiguration conf = fg.configuration(); /* Reconstruct the graph We'll do the reconstruction manually here, rather than using sd.var(...), so that we have more control @@ -5995,6 +5884,35 @@ public class SameDiff extends SDBaseOps { SDVariable var = new SDVariable(n, vt, sd, shape, dtype, null); sd.variables.put(n, Variable.builder().name(n).variable(var).build()); sd.variableNameToShape.put(n, shape); + Variable v2 = sd.variables.get(n); + + //Reconstruct control dependencies + if(v.controlDepsLength() > 0){ + int num = v.controlDepsLength(); + List l = new ArrayList<>(num); + for( int i=0; i 0){ + int num = v.controlDepForOpLength(); + List l = new ArrayList<>(num); + for( int i=0; i 0){ + int num = v.controlDepsForVarLength(); + List l = new ArrayList<>(num); + for( int i=0; i 0) { + int l = fn.controlDepsLength(); + List list = new ArrayList<>(l); + for( int i=0; i 0) { + int l = fn.varControlDepsLength(); + List list = new ArrayList<>(l); + for( int i=0; i 0) { + int l = fn.controlDepForLength(); + List list = new ArrayList<>(l); + for( int i=0; i()); } if (!v.getInputsForOp().contains(df.getOwnName())) { - v.getInputsForOp( - - ).add(df.getOwnName()); + v.getInputsForOp().add(df.getOwnName()); } } @@ -6414,32 +6360,6 @@ public class SameDiff extends SDBaseOps { return sb.toString(); } - /** - * Calculate data types for the variables in the graph - */ - public Map calculateOutputDataTypes() { - return calculateOutputDataTypes(false); - } - - /** - * Calculate data types for the variables in the graph - */ - public Map calculateOutputDataTypes(boolean dynamicUpdate) { - List allVars = new ArrayList<>(variables.keySet()); - DataTypesSession session = new DataTypesSession(this, dynamicUpdate); - Map phValues = new HashMap<>(); - for (Variable v : variables.values()) { - if (v.getVariable().isPlaceHolder()) { - org.nd4j.linalg.api.buffer.DataType dt = v.getVariable().dataType(); - Preconditions.checkNotNull(dt, "Placeholder variable %s has null datatype", v.getName()); - phValues.put(v.getName(), dt); - } - } - Map out = session.output(allVars, phValues, null, - Collections.emptyList(), Collections.emptyList(), At.defaultAt(Operation.INFERENCE)); - return out; - } - /** * For internal use only. * Creates a new discinct block name from baseName. @@ -6470,14 +6390,14 @@ public class SameDiff extends SDBaseOps { * @return The imported graph */ public static SameDiff importFrozenTF(File graphFile) { - return TFGraphMapper.getInstance().importGraph(graphFile); + return TFGraphMapper.importGraph(graphFile); } /** * See {@link #importFrozenTF(File)} */ public static SameDiff importFrozenTF(GraphDef graphDef) { - return TFGraphMapper.getInstance().importGraph(graphDef); + return TFGraphMapper.importGraph(graphDef); } @@ -6487,7 +6407,7 @@ public class SameDiff extends SDBaseOps { * Again, the input can be text or binary. */ public static SameDiff importFrozenTF(InputStream graph) { - return TFGraphMapper.getInstance().importGraph(graph); + return TFGraphMapper.importGraph(graph); } @@ -6511,7 +6431,7 @@ public class SameDiff extends SDBaseOps { int start = 1; // if we already have a name like "op_2", start from trying "op_3" - if (base.contains("_")) { + if (base.contains("_") && base.matches(".*_\\d+")) { // extract number used to generate base Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base); // extract argIndex used to generate base diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java new file mode 100644 index 000000000..776d26794 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java @@ -0,0 +1,444 @@ +package org.nd4j.autodiff.samediff.internal; + +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.function.Predicate; +import org.nd4j.linalg.primitives.Pair; + +import java.util.*; + +/** + * Object dependency tracker. + *
+ * Dependency are denoted by: X -> Y, which means "Y depends on X"
+ * In this implementation:
+ * - Dependencies may be satisfied, or not satisfied
+ * - The implementation tracks when the dependency for an object Y are fully satisfied. This occurs when:
+ * 1. No dependencies X->Y exist
+ * 2. All dependencies of the form X->Y have been marked as satisfied, via markSatisfied(x)
+ * - When a dependency is satisfied, any dependent (Ys) are checked to see if all their dependencies are satisfied
+ * - If a dependent has all dependencies satisfied, it is added to the "new all satisfied" queue for processing, + * which can be accessed via {@link #hasNewAllSatisfied()}, {@link #getNewAllSatisfied()} and {@link #getNewAllSatisfiedList()}
+ *
+ * Note: Two types of dependencies exist
+ * 1. Standard dependencies - i.e., "Y depends on X"
+ * 2. "Or" dependencies - i.e., "Y depends on (A or B)".
+ * For Or dependencies of the form "(A or B) -> Y", Y will be marked as "all dependencies satisfied" if either A or B is marked as satisfied. + * + * @param For a dependency X -> Y, Y has type T + * @param For a dependency X -> Y, X has type D + */ +@Slf4j +public abstract class AbstractDependencyTracker { + @Getter + private final Map> dependencies; //Key: the dependent. Value: all things that the key depends on + @Getter + private final Map>> orDependencies; //Key: the dependent. Value: the set of OR dependencies + private final Map> reverseDependencies = new HashMap<>(); //Key: the dependee. Value: The set of all dependents that depend on this value + private final Map> reverseOrDependencies = new HashMap<>(); + private final Set satisfiedDependencies = new HashSet<>(); //Mark the dependency as satisfied. If not in set: assumed to not be satisfied + + private final Set allSatisfied; //Set of all dependent values (Ys) that have all dependencies satisfied + private final Queue allSatisfiedQueue = new LinkedList<>(); //Queue for *new* "all satisfied" values. Values are removed using the "new all satisfied" methods + + + protected AbstractDependencyTracker() { + dependencies = (Map>) newTMap(); + orDependencies = (Map>>) newTMap(); + allSatisfied = newTSet(); + } + + /** + * @return A new map where the dependents (i.e., Y in "X -> Y") are the key + */ + protected abstract Map newTMap(); + + /** + * @return A new set where the dependents (i.e., Y in "X -> Y") are the key + */ + protected abstract Set newTSet(); + + /** + * @return A String representation of the dependent object + */ + protected abstract String toStringT(T t); + + /** + * @return A String representation of the dependee object + */ + protected abstract String toStringD(D d); + + /** + * Clear all internal state for the dependency tracker + */ + public void clear() { + dependencies.clear(); + orDependencies.clear(); + reverseDependencies.clear(); + reverseOrDependencies.clear(); + satisfiedDependencies.clear(); + allSatisfied.clear(); + allSatisfiedQueue.clear(); + } + + /** + * @return True if no dependencies have been defined + */ + public boolean isEmpty() { + return dependencies.isEmpty() && orDependencies.isEmpty() && + allSatisfiedQueue.isEmpty(); + } + + /** + * @return True if the dependency has been marked as satisfied using {@link #markSatisfied(Object, boolean)} + */ + public boolean isSatisfied(@NonNull D x) { + return satisfiedDependencies.contains(x); + } + + /** + * Mark the specified value as satisfied. + * For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true) + * call, both of these dependencies are considered satisfied. + * + * @param x Value to mark + * @param satisfied Whether to mark as satisfied (true) or unsatisfied (false) + */ + public void markSatisfied(@NonNull D x, boolean satisfied) { + if (satisfied) { + boolean alreadySatisfied = satisfiedDependencies.contains(x); + + if (!alreadySatisfied) { + satisfiedDependencies.add(x); + + //Check if any Y's exist that have dependencies that are all satisfied, for X -> Y + Set s = reverseDependencies.get(x); + Set s2 = reverseOrDependencies.get(x); + + Set set; + if (s != null && s2 != null) { + set = newTSet(); + set.addAll(s); + set.addAll(s2); + } else if (s != null) { + set = s; + } else if (s2 != null) { + set = s2; + } else { + if (log.isTraceEnabled()) { + log.trace("No values depend on: {}", toStringD(x)); + } + return; + } + + for (T t : set) { + Set required = dependencies.get(t); + Set> requiredOr = orDependencies.get(t); + boolean allSatisfied = true; + if (required != null) { + for (D d : required) { + if (!isSatisfied(d)) { + allSatisfied = false; + break; + } + } + } + if (allSatisfied && requiredOr != null) { + for (Pair p : requiredOr) { + if (!isSatisfied(p.getFirst()) && !isSatisfied(p.getSecond())) { + allSatisfied = false; + break; + } + } + } + + if (allSatisfied) { + if (!this.allSatisfied.contains(t)) { + this.allSatisfied.add(t); + this.allSatisfiedQueue.add(t); + } + } + } + } + + } else { + satisfiedDependencies.remove(x); + if (!allSatisfied.isEmpty()) { + + Set reverse = reverseDependencies.get(x); + if (reverse != null) { + for (T y : reverse) { + if (allSatisfied.contains(y)) { + allSatisfied.remove(y); + allSatisfiedQueue.remove(y); + } + } + } + Set orReverse = reverseOrDependencies.get(x); + if (orReverse != null) { + for (T y : orReverse) { + if (allSatisfied.contains(y) && !isAllSatisfied(y)) { + allSatisfied.remove(y); + allSatisfiedQueue.remove(y); + } + } + } + } + } + } + + /** + * Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)} + * or {@link #addOrDependency(Object, Object, Object)} + * + * @param y Dependent to check + * @return True if Y depends on any values + */ + public boolean hasDependency(@NonNull T y) { + Set s1 = dependencies.get(y); + if (s1 != null && !s1.isEmpty()) + return true; + + Set> s2 = orDependencies.get(y); + return s2 != null && !s2.isEmpty(); + } + + /** + * Get all dependencies x, for x -> y, and (x1 or x2) -> y + * + * @param y Dependent to get dependencies for + * @return List of dependencies + */ + public DependencyList getDependencies(@NonNull T y) { + Set s1 = dependencies.get(y); + Set> s2 = orDependencies.get(y); + + List l1 = (s1 == null ? null : new ArrayList<>(s1)); + List> l2 = (s2 == null ? null : new ArrayList<>(s2)); + + return new DependencyList<>(y, l1, l2); + } + + /** + * Add a dependency: y depends on x, as in x -> y + * + * @param y The dependent + * @param x The dependee that is required for Y + */ + public void addDependency(@NonNull T y, @NonNull D x) { + if (!dependencies.containsKey(y)) + dependencies.put(y, new HashSet()); + + if (!reverseDependencies.containsKey(x)) + reverseDependencies.put(x, newTSet()); + + dependencies.get(y).add(x); + reverseDependencies.get(x).add(y); + + checkAndUpdateIfAllSatisfied(y); + } + + protected void checkAndUpdateIfAllSatisfied(@NonNull T y) { + boolean allSat = isAllSatisfied(y); + if (allSat) { + //Case where "x is satisfied" happened before x->y added + if (!allSatisfied.contains(y)) { + allSatisfied.add(y); + allSatisfiedQueue.add(y); + } + } else if (allSatisfied.contains(y)) { + if (!allSatisfiedQueue.contains(y)) { + StringBuilder sb = new StringBuilder(); + sb.append("Dependent object \"").append(toStringT(y)).append("\" was previously processed after all dependencies") + .append(" were marked satisfied, but is now additional dependencies have been added.\n"); + DependencyList dl = getDependencies(y); + if (dl.getDependencies() != null) { + sb.append("Dependencies:\n"); + for (D d : dl.getDependencies()) { + sb.append(d).append(" - ").append(isSatisfied(d) ? "Satisfied" : "Not satisfied").append("\n"); + } + } + if (dl.getOrDependencies() != null) { + sb.append("Or dependencies:\n"); + for (Pair p : dl.getOrDependencies()) { + sb.append(p).append(" - satisfied=(").append(isSatisfied(p.getFirst())).append(",").append(isSatisfied(p.getSecond())).append(")"); + } + } + throw new IllegalStateException(sb.toString()); + } + + //Not satisfied, but is in the queue -> needs to be removed + allSatisfied.remove(y); + allSatisfiedQueue.remove(y); + } + } + + protected boolean isAllSatisfied(@NonNull T y) { + Set set1 = dependencies.get(y); + + boolean allSatisfied = true; + if (set1 != null) { + for (D d : set1) { + allSatisfied = isSatisfied(d); + if (!allSatisfied) + break; + } + } + if (allSatisfied) { + Set> set2 = orDependencies.get(y); + if (set2 != null) { + for (Pair p : set2) { + allSatisfied = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond()); + if (!allSatisfied) + break; + } + } + } + return allSatisfied; + } + + + /** + * Remove a dependency (x -> y) + * + * @param y The dependent that currently requires X + * @param x The dependee that is no longer required for Y + */ + public void removeDependency(@NonNull T y, @NonNull D x) { + if (!dependencies.containsKey(y) && !orDependencies.containsKey(y)) + return; + + Set s = dependencies.get(y); + if (s != null) { + s.remove(x); + if (s.isEmpty()) + dependencies.remove(y); + } + + Set s2 = reverseDependencies.get(x); + if (s2 != null) { + s2.remove(y); + if (s2.isEmpty()) + reverseDependencies.remove(x); + } + + + Set> s3 = orDependencies.get(y); + if (s3 != null) { + boolean removedReverse = false; + Iterator> iter = s3.iterator(); + while (iter.hasNext()) { + Pair p = iter.next(); + if (x.equals(p.getFirst()) || x.equals(p.getSecond())) { + iter.remove(); + + if (!removedReverse) { + Set set1 = reverseOrDependencies.get(p.getFirst()); + Set set2 = reverseOrDependencies.get(p.getSecond()); + + set1.remove(y); + set2.remove(y); + + if (set1.isEmpty()) + reverseOrDependencies.remove(p.getFirst()); + if (set2.isEmpty()) + reverseOrDependencies.remove(p.getSecond()); + + removedReverse = true; + } + } + } + } + if (s3 != null && s3.isEmpty()) + orDependencies.remove(y); + } + + /** + * Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y
+ * If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the + * dependency is considered satisfied + * + * @param y Dependent + * @param x1 Dependee 1 + * @param x2 Dependee 2 + */ + public void addOrDependency(@NonNull T y, @NonNull D x1, @NonNull D x2) { + if (!orDependencies.containsKey(y)) + orDependencies.put(y, new HashSet>()); + + if (!reverseOrDependencies.containsKey(x1)) + reverseOrDependencies.put(x1, newTSet()); + if (!reverseOrDependencies.containsKey(x2)) + reverseOrDependencies.put(x2, newTSet()); + + orDependencies.get(y).add(new Pair<>(x1, x2)); + reverseOrDependencies.get(x1).add(y); + reverseOrDependencies.get(x2).add(y); + + checkAndUpdateIfAllSatisfied(y); + } + + /** + * @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y) + */ + public boolean hasNewAllSatisfied() { + return !allSatisfiedQueue.isEmpty(); + } + + /** + * Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)} + * Throws an exception if {@link #hasNewAllSatisfied()} returns false.
+ * Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value; + * the value is considered "processed" at this point. + * + * @return The next new "all satisfied dependent" + */ + public T getNewAllSatisfied() { + Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied"); + return allSatisfiedQueue.remove(); + } + + /** + * @return As per {@link #getNewAllSatisfied()} but returns all values + */ + public List getNewAllSatisfiedList() { + Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied"); + List ret = new ArrayList<>(allSatisfiedQueue); + allSatisfiedQueue.clear(); + return ret; + } + + /** + * As per {@link #getNewAllSatisfied()} but instead of returning the first dependee, it returns the first that matches + * the provided predicate. If no value matches the predicate, null is returned + * + * @param predicate Predicate gor checking + * @return The first value matching the predicate, or null if no values match the predicate + */ + public T getFirstNewAllSatisfiedMatching(@NonNull Predicate predicate) { + Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied"); + + T t = allSatisfiedQueue.peek(); + if (predicate.test(t)) { + t = allSatisfiedQueue.remove(); + allSatisfied.remove(t); + return t; + } + + if (allSatisfiedQueue.size() > 1) { + Iterator iter = allSatisfiedQueue.iterator(); + while (iter.hasNext()) { + t = iter.next(); + if (predicate.test(t)) { + iter.remove(); + allSatisfied.remove(t); + return t; + } + } + } + + return null; //None match predicate + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index 387e25f48..cbdb39cd6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -16,46 +16,59 @@ package org.nd4j.autodiff.samediff.internal; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.Getter; -import lombok.NonNull; +import lombok.*; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.Listener; -import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.function.Predicate; import java.util.*; -import org.nd4j.linalg.dataset.api.MultiDataSet; -import org.nd4j.linalg.factory.Nd4j; /** - * Additional functionality to add: - * - Workspaces support - * - Proper cache support + * AbstractSession is a SameDiff graph execution class that inference and training it built upon + * It walks through the graph, dynamically executing operations that can be executed next, but (again, dynamically) only + * executing the subset of the graph that is actually required to get the requested outputs.
+ * None of what AbstractSession implements is NDArray-specific.
+ * Note that most of the implementation complexity comes from dynamic graphs - i.e., nested loops, control ops, etc * * @param Node output type - for example, INDArray, shape, etc depending on what we're calculating * @param Op type + * @author Alex Black */ @Slf4j public abstract class AbstractSession { - //All execution happens in a frame... this is the name of the main/outer frame + /** + * All execution in Samediff happens in a frame... this is the name of the main/outer frame - i.e., the "default" frame + * Other frames (such as for loops) may be nested within this frame + */ public static final String OUTER_FRAME = "main"; protected final SameDiff sameDiff; @Getter - protected final Map nodeOutputs = new HashMap<>(); + protected final Map nodeOutputs = new HashMap<>(); //Key: variable (at a given frame + iteration). Value: the calculated output for that variable @Getter - protected final Map> tensorArrays = new HashMap<>(); //Stores the outputs for a TensorArray ops - protected final Queue availableForExec = new LinkedList<>(); - protected final Set availableForExecSet = new HashSet<>(); //Same content as the queue, but used for O(1) contains instead of ordered removal + protected final Map> tensorArrays = new HashMap<>(); //Stores the underlying arrays for TensorArray ops + /* + The dependency tracker is responsible for determining what ops (at what frame/iteration) can be executed next, given + what has been executed so far. + For static graphs, such as abstraction would not be necessary; for dynamic graphs (i.e., nested loops, of arbitary + number of iterations and depth - and also switch ops which can cause whole subgraphs to not be executed) this is necessary + Note: the ExecStep represents one step for execution - some steps are as simple as "execute an op (at the given frame/iter)" + It works by adding dependencies (X -> Y - such as "op Y depends on the output of op X") and then marking them as + satisfied ("op X has been calculated"). Once all dependencies for an execution step have been satisfied, the execution step + is added to a queue - outputs of which can be accessed with dt.getNewAllSatisfied() and dt.getNewAllSatisfiedList(), + at which point it is removed from the dependency tracker + */ + protected final DependencyTracker dt = new DependencyTracker<>(); + /** * Contains variables we *might* need to execute in process of getting outputs we want. * Variables not in this set are definitely not needed to get the requested output variables, but variables that are @@ -63,45 +76,22 @@ public abstract class AbstractSession { */ protected final Set subgraph = new HashSet<>(); /** - * Stores what variables are required to calculate the specific variable. These inputs could be inputs to an op that - * calculates the variable's value, or it could be a control dependenci - * Keys: variable (in specific frame/iteration) to be executed - * Values: inputs to that node (inc. frame and iteration), unordered - needed for execution of op giving variable + * As per subgraph set, but for ops instead */ - protected final Map> execInputs = new HashMap<>(); + protected final Set subgraphOps = new HashSet<>(); /** - * As per execInputs map - with the different that the iteration number should be ignored (i.e., always 0) - * Reason: Enter nodes - these are executed once - * Example: EnterOp(x) -> LoopCondition(less(x,y)): less op requires "X" on all iterations which is the output of the - * enter op, which is only executed for iteration 0 in a frame. + * Constains the names of ops that don't have any inputs. Kept because normally ops are triggered for execution when + * their all their inputs have been calculated; we'll trigger that step manually during execution initialization */ - protected final Map> execInputsAllIter = new HashMap<>(); - - /** - * Contains the set set of constant and placeholders inputs - * Essentially the same as the execInputs map, but the constants and placeholders are used for calculating all instances - * of a variable - i.e., the input (constant/placeholder) applies to all frames and iterations. - * Keys: variable (any/all frame/iteration) to be executed - * Values: constant or placeholder needed for execution of op giving variable - */ - protected final Map> execConstInputs = new HashMap<>(); - /** - * Map for exit ops. This is used to determine where an exit op should exit to. - * Values added on enter ops. Note that it's not sufficient to - * Key: frame name (for enter/exit nodes). - * Value: parent frame name + iteration - */ - @Getter - protected final Map frameParents = new HashMap<>(); - + protected final Set zeroInputOpsInSubgraph = new HashSet<>(); public AbstractSession(@NonNull SameDiff sameDiff) { this.sameDiff = sameDiff; } - public boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter){ - VarId varId = newVarId(variable, frame, iteration, parentFrameIter); + public boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter) { + VarId varId = new VarId(variable, frame, iteration, parentFrameIter); return nodeOutputs.containsKey(varId); } @@ -114,62 +104,36 @@ public abstract class AbstractSession { /** * Get a previously calculated output + * * @param enforceExistence If true: throw an exception if the array does not exist */ public T get(String variable, String frame, int iteration, FrameIter parentFrameIter, boolean enforceExistence) { //TODO eventually we'll cache and reuse VarId objects here to avoid garbage generation on lookup etc - VarId varId = newVarId(variable, frame, iteration, parentFrameIter); + VarId varId = new VarId(variable, frame, iteration, parentFrameIter); T out = nodeOutputs.get(varId); - if(enforceExistence) { + if (enforceExistence) { Preconditions.checkNotNull(out, "No output found for variable %s (frame %s, iteration %s)", variable, frame, iteration); } return out; } - public VarId newVarId(String variable, String frame, int iteration, FrameIter parentFrameIter) { - //TODO eventually we'll cache and reuse VarId objects here to avoid garbage generation on lookup - return new VarId(variable, frame, iteration, parentFrameIter); - } - - public VarId newVarId(String variable, FrameIter frameIter) { - return newVarId(variable, frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame()); - } - /** - * @deprecated Use {@link #output(List, Map, MultiDataSet, Collection, List, At)}. + * Get the output of the session - i.e., perform inference/forward pass and return the autputs for the specified variables * - * @param training Uses Operation.TRAINING if true, otherwise Operation.INFERENCE - */ - @Deprecated - public Map output(@NonNull List variables, Map placeholderValues, - MultiDataSet batch, Collection requiredActivations, boolean training, At at){ - if(at == null){ - if(training) - at = At.defaultAt(Operation.TRAINING); - else - at = At.defaultAt(Operation.INFERENCE); - } - return output(variables, placeholderValues, batch, requiredActivations, Collections.emptyList(), at); - } - - /** - * Get the output of the session - i.e., perform inference/forward pass - * - * @param variables Name of the variables we want the arrays/activations for - * @param placeholderValues The placeholder values (if any). - * @param batch The batch data, used to call Listener.opExecution - * @param requiredActivations Additional activations that are required. Won't be outputed, but opExecution will be called. May be null. + * @param variables Name of the variables we want the arrays/activations for + * @param placeholderValues The placeholder values (if any). May be null. + * @param batch The batch data, used to call Listener.opExecution + * @param requiredActivations Additional activations that are required. Won't be outputed, but opExecution will be called. May be null. * @return The specified variable values, optionally in the specified workspace */ public Map output(@NonNull List variables, Map placeholderValues, - MultiDataSet batch, Collection requiredActivations, List listeners, At at) { + MultiDataSet batch, Collection requiredActivations, List listeners, At at) { + Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty"); - Preconditions.checkState(!variables.isEmpty(), "Variables to perform forward pass for must not be empty"); - - if(requiredActivations == null) + if (requiredActivations == null) requiredActivations = Collections.emptyList(); - if(at == null) + if (at == null) at = At.defaultAt(); //Step 0: validation - that variables exist, placeholders have arrays, etc @@ -177,44 +141,46 @@ public abstract class AbstractSession { Preconditions.checkState(sameDiff.variableMap().containsKey(s), "Requested output variable %s does not exist in SameDiff instance", s); } - placeholderValues = preprocessPlaceholders(placeholderValues); + Set reqOutputVariablesSet = new HashSet<>(variables); - //Clear state from past - availableForExec.clear(); - availableForExecSet.clear(); + placeholderValues = preprocessPlaceholders(placeholderValues, at); + + //Clear state from past iterations, if any + dt.clear(); subgraph.clear(); - execInputs.clear(); - execInputsAllIter.clear(); - execConstInputs.clear(); - nodeOutputs.clear(); //TODO eventually we'll have cache here for later execs... main challenge is detecting in-place array modifications and invalidating old results + subgraphOps.clear(); + nodeOutputs.clear(); //TODO eventually we'll have (optional) cache here for later execs... main challenge is detecting in-place array modifications and invalidating old results. And overall memory use... tensorArrays.clear(); //Step 1: determine subgraph structure we actually need to execute //Basic plan: work backwards from the variables we want, based on the graph structure, to work out what // we actually need to execute - List allRequired = new ArrayList<>(requiredActivations); + //TODO we'll optimize this and cache the results, only recalculating if the graph structure changes + Set userRequestedUnique = new HashSet<>(variables); + Set allRequired = new HashSet<>(requiredActivations); allRequired.addAll(variables); initSubgraph(allRequired); - //Step 1a: Check that we have required placeholders + //Step 2: Check that we have required placeholders List phNames = sameDiff.inputs(); - if(placeholderValues == null || !placeholderValues.keySet().containsAll(phNames)){ + if (placeholderValues == null || !placeholderValues.keySet().containsAll(phNames)) { /* We only have a subset of all placeholders Validate that we have all *required* placeholder values. Some might not be needed to calculate the requested outputs A placeholder is required if: (a) It's one of the requested outputs (b) It's required to calculate any of the ops in the subgraph + For example, we might have a label placeholder, and we're doing inference not training */ - for(String s : phNames){ + for (String s : phNames) { boolean required = false; - if(variables.contains(s)){ //TODO List.contains - O(N) + if (variables.contains(s)) { required = true; } - if(!required){ + if (!required) { Variable v = sameDiff.getVariables().get(s); - if(v.getInputsForOp() != null){ - for(String s2 : v.getInputsForOp()){ - if(subgraph.contains(s2)){ + if (v.getInputsForOp() != null) { + for (String s2 : v.getInputsForOp()) { + if (subgraph.contains(s2)) { //Placeholder is required required = true; break; @@ -223,200 +189,562 @@ public abstract class AbstractSession { } } - if(required && (placeholderValues == null || !placeholderValues.containsKey(s))){ - - // Some Keras layers (like GRU) do different things depending on whether the model is training. - // We provide this value directly. - if(s.endsWith("keras_learning_phase")){ - placeholderValues.put(s, (T) Nd4j.scalar(at.operation().isTrainingPhase())); - } else { - throw new IllegalStateException( - "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + - " but a placeholder value was not provided"); - } + if (required && (placeholderValues == null || !placeholderValues.containsKey(s))) { + throw new IllegalStateException( + "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + + " but a placeholder value was not provided"); } } } - //Step 2: execute in any order, until we have all required nodeOutputs + //Step 3: Mark the (required) variables, constants and placeholders as available via dependency tracker + //And also any "zero dependency" ops - i.e., those without any inputs + ExecStep start = new ExecStep(ExecType.EXEC_START, "", null); //Dummy dependency to trigger the variables and constants + for (SDVariable v : sameDiff.variables()) { + VariableType vt = v.getVariableType(); + if (vt == VariableType.VARIABLE || vt == VariableType.CONSTANT) { + ExecType et = vt == VariableType.VARIABLE ? ExecType.VARIABLE : ExecType.CONSTANT; + ExecStep es = new ExecStep(et, v.getVarName(), new FrameIter(OUTER_FRAME, 0, null)); + dt.addDependency(es, start); + + Variable var = sameDiff.getVariables().get(v.getVarName()); + if (var.getControlDeps() != null) { + addVarControlDeps(es, var); //Before this variable can be considered available for use, we need specified op to be executed + } + } + } + for (String s : phNames) { + ExecStep es = new ExecStep(ExecType.PLACEHOLDER, s, new FrameIter(OUTER_FRAME, 0, null)); + dt.addDependency(es, start); + + Variable var = sameDiff.getVariables().get(s); + if (var.getControlDeps() != null) { + addVarControlDeps(es, var); //Before this variable can be considered available for use, we need specified op to be executed + } + } + for (String s : zeroInputOpsInSubgraph) { + ExecStep es = new ExecStep(ExecType.OP, s, new FrameIter(OUTER_FRAME, 0, null)); + dt.addDependency(es, start); + } + dt.markSatisfied(start, true); + + + //Step 4: execute in any order, but not switching to new frame/iteration until all from current frame/iter ops + // are done - until we have all required nodeOutputs /* - The idea is simple: we start off with a set of "available to execute" variables - just the placeholders and - constants at this point. + The idea is simple: we start off with a set of "available to execute" variables - just the placeholders, + constants and variables (assuming no control dependencies) at the start of execution. Then, we remove an "available to execute" node and execute it. Execution may be: - (a) For constants and placeholders: just looking up the value - (b) For variables as outputs of ops: actually executing the op + (a) For constants, variable type SDVariables, and placeholders: just look up the value + (b) For variables as outputs of ops: actually execute the op After execution, we look at the graph structure and determine what that now executed/calculated variable is an input to. If all inputs are available for the op, we mark all output variables of that op as available for execution. + Both parts of this (tracking dependencies, and also what's now available to execute) are handled in the dependency tracker We stop computation once all the required outputs are available. At this point, subgraph may NOT be empty - for example, switch ops may cause entire branches of the graph to be skipped. */ - Map out = new HashMap<>(); - int step = 0; - while (out.size() < variables.size()) { - if(availableForExec.size() == 0){ - int missingCount = variables.size() - out.size(); - StringBuilder sb = new StringBuilder(); - sb.append("No variable are available for execution at step ") - .append(step).append(": ").append(missingCount).append(" values remaining"); - Set missing = new HashSet<>(); - for(String s : variables){ - if(!out.containsKey(s)){ - missing.add(s); - } + Map out = new HashMap<>(); //Outputs, returned to the user + int step = 0; //Number of execution steps + //Next 3: current execution frame + String currentFrame = OUTER_FRAME; + int currentFrameIter = 0; + FrameIter currParentFrame = null; + ExecStepPredicate predicate = new ExecStepPredicate(); + while (out.size() < userRequestedUnique.size()) { + if (!dt.hasNewAllSatisfied()) { + //Haven't got all of the outputs the user requested, but there's nothing left that we can execute. Should not happen. + execFailed(userRequestedUnique, out, step); + } + + //Get variable in the current frame/iteration and execute it's corresponding op + //If no more ops exist for the current frame/iter, we'll switch to the next frame/iter + //The idea is to not mix the order of execution of ops in different frames/iters - i.e., finish the current + // frame/iter before starting the next one + predicate.setCurrentFrame(currentFrame); + predicate.setCurrentFrameIter(currentFrameIter); + predicate.setCurrParentFrame(currParentFrame); + + ExecStep es = dt.getFirstNewAllSatisfiedMatching(predicate); + if (es == null) { + //We must have finished the current frame/iter, and are switching to the next one + es = dt.getNewAllSatisfied(); + } + + currentFrame = es.getFrameIter().getFrame(); + currentFrameIter = es.getFrameIter().getIteration(); + currParentFrame = es.getFrameIter().getParentFrame(); + + log.trace("Beginning execution step {}: {}", step, es); + + FrameIter outFrameIter; + boolean skipDepUpdate = false; //Only used for Switch ops, which have slighly different handling... + boolean skipMarkSatisfied = false; //Only for enter ops, because of different frame/iter + if (es.getType() == ExecType.CONSTANT || es.getType() == ExecType.VARIABLE) { + VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null); + T arr = getConstantOrVariable(es.getName()); + Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid); + nodeOutputs.put(vid, arr); + outFrameIter = new FrameIter(OUTER_FRAME, 0, null); + if (allRequired.contains(es.getName())) { + //User requested const/variable as one of the outputs + out.put(es.getName(), arr); } - if(missingCount <= 10){ - sb.append(". Missing variables: "); - sb.append(missing); + } else if (es.getType() == ExecType.PLACEHOLDER) { + VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null); + nodeOutputs.put(vid, placeholderValues.get(es.getName())); + outFrameIter = new FrameIter(OUTER_FRAME, 0, null); + if (allRequired.contains(es.getName())) { + //User requested placeholder value as one of the outputs + out.put(es.getName(), placeholderValues.get(es.getName())); + } + } else if (es.getType() == ExecType.OP) { + String opName = es.getName(); + SameDiffOp op = sameDiff.getOps().get(opName); + DifferentialFunction o = op.getOp(); + + if (o instanceof Enter) { + //Enter op: output is variable in a new (specified) frame, iteration 0. + //Parent is current (input) frame + String outFrame = ((Enter) o).getFrameName(); + outFrameIter = new FrameIter(outFrame, 0, es.getFrameIter()); + } else if (o instanceof Exit) { + //Exit node forwards input to parent frame + String outFrame = es.getFrameIter().getParentFrame().getFrame(); + int outIter = es.getFrameIter().getParentFrame().getIteration(); + FrameIter outParentFrame = es.getFrameIter().getParentFrame().getParentFrame(); + outFrameIter = new FrameIter(outFrame, outIter, outParentFrame); + } else if (o instanceof NextIteration) { + //NextIteration op: forwards its single input to its output varible in the current frame, but increments the iteration number + outFrameIter = es.getFrameIter().clone(); + outFrameIter.setIteration(outFrameIter.getIteration()); } else { - sb.append(". First 10 missing variables: "); - Iterator iter = missing.iterator(); - for( int i=0; i<10 && iter.hasNext(); i++ ){ - if(i > 0) - sb.append(","); - sb.append(iter.next()); + //Standard ops - output variable has same frame and iteration number as the input(s) + //Also loopCond, merge, while, etc + outFrameIter = es.getFrameIter(); + } + + + //Resolve the inputs to this execution step (op) to actual arrays + Set inputs = null; + Set allIterInputs = null; + Set constAndPhInputs = null; + DependencyList dl = dt.getDependencies(es); + + List inputNames = op.getInputsToOp(); + if (inputNames != null && !inputNames.isEmpty()) { + inputs = new HashSet<>(); + allIterInputs = new HashSet<>(); + constAndPhInputs = new HashSet<>(); + List deps = dl.getDependencies(); + if (deps != null && !deps.isEmpty()) { + for (ExecStep dep : deps) { + switch (dep.getType()) { + case OP: + case SWITCH_L: + case SWITCH_R: + //The current execution step depends on one output of the op "dep" + SameDiffOp toExecOp = sameDiff.getOps().get(es.getName()); + List inputsToExecOp = toExecOp.getInputsToOp(); + SameDiffOp inputOp = sameDiff.getOps().get(dep.getName()); + List inputOpOutNames = inputOp.getOutputsOfOp(); + for (String s : inputsToExecOp) { + if (inputOpOutNames.contains(s)) { + VarId vid = new VarId(s, dep.getFrameIter().getFrame(), dep.getFrameIter().getIteration(), dep.getFrameIter().getParentFrame()); + inputs.add(vid); + } + } + break; + case VARIABLE: + inputs.add(new VarId(dep.getName(), OUTER_FRAME, 0, null)); + break; + case CONSTANT: + case PLACEHOLDER: + constAndPhInputs.add(dep.getName()); + break; + default: + throw new UnsupportedOperationException("Not yet implemented: " + dep.getType()); + } + } } } - String s = sb.toString(); - throw new IllegalStateException(s); - } - - //Get any variable and execute it's corresponding op - VarId varToExec = availableForExec.remove(); - availableForExecSet.remove(varToExec); - if (nodeOutputs.containsKey(varToExec)) { - //Already processed this one. May occur if execution was triggered by a different output of a multi-output op - //But we'll still update its descendants to ensure they are marked as available - if (variables.contains(varToExec.getVariable())) { //Check if required output - out.put(varToExec.getVariable(), nodeOutputs.get(varToExec)); - } - updateDescendentsForExec(step, varToExec); - continue; - } - - //Get inputs to this variable. May be actual op inputs, or just control dependencies - Set inputsToVar = execInputs.get(varToExec); - VarId allIterInputVar = newVarId(varToExec.getVariable(), varToExec.getFrame(), 0, varToExec.getParentFrame()); - Set inputsToVarAllIter = execInputsAllIter.get(allIterInputVar); - Set constPhForVar = execConstInputs.get(varToExec.getVariable()); - - log.trace("Beginning execution step {}: variable {}", step, varToExec); - - if (sameDiff.getVariable(varToExec.getVariable()).isPlaceHolder()) { - //Variable is placeholder: do lookup - nodeOutputs.put(varToExec, placeholderValues.get(varToExec.getVariable())); - updateDescendentsForExec(step, varToExec); //Check + mark descendants as available for exec - if (variables.contains(varToExec.getVariable())) { //Check if required output - out.put(varToExec.getVariable(), placeholderValues.get(varToExec.getVariable())); - } - } else if (sameDiff.getVariable(varToExec.getVariable()).isConstant() || - sameDiff.getVariable(varToExec.getVariable()).getVariableType() == VariableType.VARIABLE) { - //Variable is constant: do lookup - //OR variable is VARIABLE type - i.e., a trainable parameter... - T phArr = getConstantOrVariable(varToExec.getVariable()); - Preconditions.checkNotNull(phArr, "Encountered null placeholder array for constant: %s", varToExec); - nodeOutputs.put(varToExec, phArr); - updateDescendentsForExec(step, varToExec); //Check + mark descendants as available for exec - if (variables.contains(varToExec.getVariable())) { //Check if required output - out.put(varToExec.getVariable(), phArr); - } - } else if (sameDiff.getVariableOutputOp(varToExec.getVariable()) != null) { - //Variable is the output of an op -> execute op - String opName = sameDiff.getVariables().get(varToExec.getVariable()).getOutputOfOp(); + // Do execution of the op, in 2 steps + // (a) "Parameterize" the op - i.e., find and set the arrays on the op, allocate outputs, etc ready for execution + // (b) actually execute the operation + O parameterizedOp = getAndParameterizeOp(opName, outFrameIter, inputs, allIterInputs, constAndPhInputs, placeholderValues, reqOutputVariablesSet); + T[] opOutputValues = getOutputs(parameterizedOp, outFrameIter, inputs, allIterInputs, constAndPhInputs, listeners, at, batch, reqOutputVariablesSet); + List opOutVarNames = op.getOutputsOfOp(); - //Execute op - FrameIter frameIter = varToExec.toFrameIter(); - O parameterizedOp = getAndParameterizeOp(opName, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, placeholderValues); - T[] opOutputValues = getOutputs(parameterizedOp, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, listeners, at, batch); - - - //Post execution: work out what is now available for exec - String[] opOutputVarNames = sameDiff.getOpById(opName).outputVariablesNames(); - - Preconditions.checkState(opOutputValues.length == opOutputVarNames.length, "Unexpected number of outputs from executed op %s:" + + Preconditions.checkState(opOutputValues.length == opOutVarNames.size(), "Unexpected number of outputs from executed op %s:" + " got %s outputs when %s outputs were expected (%s)", parameterizedOp.getClass().getSimpleName(), opOutputValues.length, - opOutputVarNames.length, opOutputVarNames); + opOutVarNames.size(), opOutVarNames); - for (int i = 0; i < opOutputVarNames.length; i++) { - if (opOutputValues[i] == null && parameterizedOp instanceof Switch) { - //Skip null - for switch op only. Switch op forwards input to only one of its outputs - //All other ops should not + //Store the op outputs + for (int i = 0; i < opOutputValues.length; i++) { + if (opOutputValues[i] == null && op.getOp() instanceof Switch) { + //Switch op only forwards the input to one of the outputs continue; } - Preconditions.checkNotNull(opOutputValues[i], "Encountered null output (output %s) for op %s at execution step %s", i, parameterizedOp.getClass().getSimpleName(), step); + String n = opOutVarNames.get(i); + VarId vid = new VarId(n, outFrameIter.getFrame(), outFrameIter.getIteration(), outFrameIter.getParentFrame()); + nodeOutputs.put(vid, opOutputValues[i]); - VarId outputVarId; - boolean addDummyOutput = false; - if (parameterizedOp instanceof Enter) { - //Enter op: output is variable in a new (specified) frame, iteration 0. - String frame = ((Enter) parameterizedOp).getFrameName(); - boolean isConstant = ((Enter) parameterizedOp).isConstant(); - FrameIter outParentFrame = varToExec.getParentFrame(); - if(isConstant && outParentFrame != null){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - outParentFrame = outParentFrame.clone(); - FrameIter toZero = outParentFrame; - while(toZero != null){ - toZero.setIteration(0); - toZero = toZero.getParentFrame(); - } - } - outputVarId = newVarId(opOutputVarNames[i], frame, 0, outParentFrame); - addDummyOutput = true; - } else if (parameterizedOp instanceof Exit) { - //Exit node forwards input to parent frame (which is already reflected in varToExec) - outputVarId = newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()); - addDummyOutput = true; - } else if (parameterizedOp instanceof NextIteration) { - //NextIteration op: forwards its single input to its output varible in the current frame, but increments the iteration number - //Note that varToExec has already had its iteration number incremented by 1 (relative to its input) in updateDescendentsForExec... so don't increment here - outputVarId = newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()); - addDummyOutput = true; - } else if (parameterizedOp instanceof LoopCond) { - //LoopCond just forwards input to output - outputVarId = newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()); - addDummyOutput = true; - } else { - //Standard ops - output variable has same frame and iteration number as the input(s) - outputVarId = newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()); - } - - if(addDummyOutput){ - //For ops like enter/exit/nextiteration, these don't have a real output for that node - //But, we still want an entry in nodeOutputs, which we also use for checking if an op has already been executed - nodeOutputs.put(newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()), null); - } - - nodeOutputs.put(outputVarId, opOutputValues[i]); - updateDescendentsForExec(step, outputVarId); //Check + mark descendants as available for exec - - if (variables.contains(opOutputVarNames[i])) { //Check if required output - out.put(opOutputVarNames[i], opOutputValues[i]); + if (allRequired.contains(n)) { + out.put(n, opOutputValues[i]); } } + + //Post execution: update dependency tracker so we know what is available to execute next, given we now + // have these new values + if (o instanceof Switch) { + /* + Switch is a special case: only one output/branch is considered to exist post execution. + Unlike every other type of op, only 1 of 2 output arrays is actually executed. + For dependency tracking purposes, this is why we have SWITCH_L and _R execution types. + If we just depended on the op, the dependency tracker would incorrectly conclude that ops relying on + both branches (i.e., including the unavailable one) can now be executed + */ + skipDepUpdate = true; + skipMarkSatisfied = true; + int nullCount = (opOutputValues[0] == null ? 1 : 0) + (opOutputValues[1] == null ? 1 : 0); + Preconditions.checkState(nullCount == 1, "Expected exactly one output to be present for switch ops, got %s", nullCount); + boolean left = opOutputValues[0] != null; + ExecStep branch; + if (left) { + branch = new ExecStep(ExecType.SWITCH_L, es.getName(), es.getFrameIter()); + } else { + branch = new ExecStep(ExecType.SWITCH_R, es.getName(), es.getFrameIter()); + } + updateDescendantDeps(branch, outFrameIter); + dt.markSatisfied(branch, true); + } else if (o instanceof Enter) { + //Enter op: we want to say that the inner frame is executed... + skipDepUpdate = true; + skipMarkSatisfied = true; + Enter e = (Enter) o; + FrameIter fi = new FrameIter(e.getFrameName(), 0, es.getFrameIter()); + ExecStep exec = new ExecStep(ExecType.OP, es.getName(), fi); + updateDescendantDeps(exec, fi); + dt.markSatisfied(exec, true); + } else if (o instanceof Exit) { + //Exit op: we want to say that the parent frame is executed... + skipDepUpdate = true; + skipMarkSatisfied = true; + FrameIter fi = es.getFrameIter().getParentFrame(); + ExecStep exec = new ExecStep(ExecType.OP, es.getName(), fi); + updateDescendantDeps(exec, fi); + dt.markSatisfied(exec, true); + } + + /* + Edge case for TensorFlow import control dependencies: for some reason, TF allows op control dependencies + like /while/x -> SomeConstant - i.e., a constant depending on something inside a scope. + This should be handled with an enter op, but TF doesn't always use this :/ + Note that this is equivalent to marking the control dependency as satisfied on the first iteration + TODO double check that this is exactly the same behaviour as TF - otherwise this approach might fail in + some rare cases that rely on the constant/variable not being available + */ + List cdFor = op.getControlDepFor(); + if (cdFor != null) { + ExecStep cdEs = new ExecStep(ExecType.CONTROL_DEP, opName, null); + if (!dt.isSatisfied(cdEs)) { + dt.markSatisfied(cdEs, true); + } + } + } else { - Variable v = sameDiff.getVariables().get(varToExec.getVariable()); - throw new IllegalStateException("Unable to execute variable " + varToExec + " of type " + v.getVariable().getVariableType()); + //Should never happen + throw new RuntimeException("Unknown ExecStep: " + es); } + + //Standard ops + if (!skipDepUpdate) { + updateDescendantDeps(es, outFrameIter); + } + if (!skipMarkSatisfied) { + dt.markSatisfied(es, true); + } + step++; } + //TODO we should clear the node outputs map to get rid of the invalid (closed, out of workspace, etc) arrays - //TODO under what circumstances should we clear the nodeOutputs map? - //TODO when should we close the workspace? (Might want to leave it open if we expect to re-use) - + out = postProcessOutput(out); //Hook-in for subclass sessions, if needed return out; } - protected void initSubgraph(List variables) { + /** + * Add the control dependency from Op -> variable + * + * @param es Execution step for the variable + * @param v Variable + */ + protected void addVarControlDeps(ExecStep es, Variable v) { + List cds = v.getControlDeps(); + if (cds != null) { + for (String s : cds) { + ExecStep controlES = new ExecStep(ExecType.CONTROL_DEP, s, null); + dt.addDependency(es, controlES); //Before this variable can be considered available for use, we need specified op to be executed + } + } + } + + /** + * Execution failed - can't calculate all requested outputs, and there's nothing left to calculate. + * Throws an exception with a useful message + * + * @param userRequestedUnique All outputs that the user requseted + * @param out Current outputs + * @param step Execution step + */ + protected void execFailed(Set userRequestedUnique, Map out, int step) { + int missingCount = userRequestedUnique.size() - out.size(); + StringBuilder sb = new StringBuilder(); + sb.append("No variable are available for execution at step ") + .append(step).append(": ").append(missingCount).append(" values remaining"); + Set missing = new HashSet<>(); + for (String s : userRequestedUnique) { + if (!out.containsKey(s)) { + missing.add(s); + } + } + if (missingCount <= 10) { + sb.append(". Missing variables: "); + sb.append(missing); + } else { + sb.append(". First 10 missing variables: "); + Iterator iter = missing.iterator(); + for (int i = 0; i < 10 && iter.hasNext(); i++) { + if (i > 0) + sb.append(","); + sb.append(iter.next()); + } + } + String s = sb.toString(); +// System.out.println(sameDiff.summary()); + throw new IllegalStateException(s); + } + + /** + * Update the descendant dependencies + * So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker + * This is for a specific frame and iteration, for both sides of the dependency (in and out) + * + * @param justExecuted The execution step that has just completed + * @param outFrameIter The frame/iteration of the output + */ + protected void updateDescendantDeps(ExecStep justExecuted, FrameIter outFrameIter) { + ExecType t = justExecuted.getType(); + String n = justExecuted.getName(); + if (justExecuted.getType() == ExecType.OP) { + SameDiffOp op = sameDiff.getOps().get(n); + List outNames = op.getOutputsOfOp(); + for (String s : outNames) { + Variable v = sameDiff.getVariables().get(s); + List inputsToOps = v.getInputsForOp(); + if (inputsToOps != null) { + for (String opName : inputsToOps) { + if (subgraphOps.contains(opName)) { + //We've just executed X, and there's dependency X -> Y + //But, there also might be a Z -> Y that we should mark as needed for Y + addDependenciesForOp(opName, outFrameIter); + } + } + } + + + //Also add control dependencies (variable) + List cdForOps = v.getControlDepsForOp(); + if (cdForOps != null) { + for (String opName : cdForOps) { + if (subgraphOps.contains(opName)) { + //We've just executed X, and there's dependency X -> Y + //But, there also might be a Z -> Y that we should mark as needed for Y + addDependenciesForOp(opName, outFrameIter); + } + } + } + } + } else if (t == ExecType.VARIABLE || t == ExecType.CONSTANT || t == ExecType.PLACEHOLDER) { + Variable v = sameDiff.getVariables().get(n); + List inputsToOps = v.getInputsForOp(); + if (inputsToOps != null) { + for (String opName : inputsToOps) { + if (subgraphOps.contains(opName)) { + addDependenciesForOp(opName, outFrameIter); + } + } + } + } else if (justExecuted.getType() == ExecType.SWITCH_L || justExecuted.getType() == ExecType.SWITCH_R) { + SameDiffOp op = sameDiff.getOps().get(n); + List outNames = op.getOutputsOfOp(); + String branchVarName = (justExecuted.getType() == ExecType.SWITCH_L ? outNames.get(0) : outNames.get(1)); + Variable v = sameDiff.getVariables().get(branchVarName); + List inputsToOps = v.getInputsForOp(); + if (inputsToOps != null) { + for (String opName : inputsToOps) { + if (subgraphOps.contains(opName)) { + //We've just executed X, and there's dependency X -> Y + //But, there also might be a Z -> Y that we should mark as needed for Y + addDependenciesForOp(opName, outFrameIter); + } + } + } + } else { + throw new UnsupportedOperationException("Unknown or not yet implemented exec type: " + justExecuted); + } + } + + /** + * Suppose operation X has just been executed. + * For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp + * (which includes X, but may not only be X) + * + * @param opName Name of the op + * @param depFrameIter Frame/iteration of the op instance to be executed + */ + protected void addDependenciesForOp(String opName, FrameIter depFrameIter) { + SameDiffOp op = sameDiff.getOps().get(opName); + List inputs = op.getInputsToOp(); + List cdOps = op.getControlDeps(); + List cdVars = op.getVarControlDeps(); + + ExecStep es = new ExecStep(ExecType.OP, opName, depFrameIter); + if (!(op.getOp() instanceof NextIteration) && dt.hasDependency(es)) { + //Already processed this once. We only add dependencies once per op (for a given frame/iteration) + return; + } + + if (op.getOp() instanceof Merge) { + //Merge ops are a special case: they can be executed with EITHER ONE of the inputs available - unlike every + // other op, we don't need all inputs, just one, before it can be executed + Variable v0 = sameDiff.getVariables().get(inputs.get(0)); + Variable v1 = sameDiff.getVariables().get(inputs.get(1)); + + ExecStep or0 = getExecStepForVar(v0.getName(), depFrameIter); + ExecStep or1 = getExecStepForVar(v1.getName(), depFrameIter); + dt.addOrDependency(es, or0, or1); + } else if (op.getOp() instanceof NextIteration) { + //For NextIteration, dependencies should be of the form X(iter) -> NextIter(iter+1) + FrameIter fi = depFrameIter.clone(); + fi.setIteration(fi.getIteration() + 1); + es = new ExecStep(ExecType.OP, opName, fi); + for (String s : inputs) { + ExecStep req = getExecStepForVar(s, depFrameIter); + dt.addDependency(es, req); + } + } else { + for (String s : inputs) { + ExecStep req = getExecStepForVar(s, depFrameIter); + dt.addDependency(es, req); + } + } + + if (cdOps != null) { + for (String s : cdOps) { + ExecStep req = getExecStepForVar(s, depFrameIter); + dt.addDependency(es, req); + } + } + + if (cdVars != null) { + for (String s : cdVars) { + + } + } + } + + /** + * Get the ExecStep for the given variable, given execution is happening at the specified frame/iteration + */ + protected ExecStep getExecStepForVar(String varName, FrameIter frameIter) { + Variable v = sameDiff.getVariables().get(varName); + VariableType vt = v.getVariable().getVariableType(); + if (vt == VariableType.VARIABLE) { + return new ExecStep(ExecType.VARIABLE, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null)); + } else if (vt == VariableType.PLACEHOLDER) { + return new ExecStep(ExecType.PLACEHOLDER, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null)); + } else if (vt == VariableType.CONSTANT) { + return new ExecStep(ExecType.CONSTANT, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null)); + } else { + //Array type. Must be output of an op + String outOfOp = v.getOutputOfOp(); + SameDiffOp sdo = sameDiff.getOps().get(outOfOp); + if (sdo.getOp() instanceof Switch) { + //For dependency tracking purposes, we track left and right output branches of switch op separately + //Otherwise, ops depending both branches will be marked as available if we just rely on "op has been executed" + List opOutputs = sdo.getOutputsOfOp(); + int idx = opOutputs.indexOf(v.getName()); + if (idx == 0) { + //Left branch + return new ExecStep(ExecType.SWITCH_L, outOfOp, frameIter); + } else if (idx == 1) { + //Right branch + return new ExecStep(ExecType.SWITCH_R, outOfOp, frameIter); + } else { + //Should never happen + throw new IllegalStateException("Expected variable \"" + v.getName() + "\" to be an output of operation \"" + + outOfOp + "\", but op output variables are: " + opOutputs); + } + } else if (sdo.getOp() instanceof Enter) { + Enter e = (Enter) sdo.getOp(); + + //For enter ops, "constant=true" enter ops are available for ALL iterations, hence use iter=0 + //For constant=false, these are only available at iteration 0 - so use *current* iteration, same as all other ops + // (which is this case, won't be triggered on iter > 0 - as desired/expected) + if (e.isConstant()) { + FrameIter fi = frameIter.clone(); + fi.setIteration(0); + + //Nested constant enter case: Iteration 0 all the way down... + String inVarName = sdo.getInputsToOp().get(0); + FrameIter parentFrame = fi.getParentFrame(); + while (parentFrame != null) { + Variable var = sameDiff.getVariables().get(inVarName); + if (var.getOutputOfOp() != null) { + String opName = var.getOutputOfOp(); + SameDiffOp sdo2 = sameDiff.getOps().get(opName); + if (sdo2.getOp() instanceof Enter) { + Enter e2 = (Enter) sdo.getOp(); + if (e2.isConstant()) { + parentFrame.setIteration(0); + parentFrame = parentFrame.getParentFrame(); + inVarName = sdo2.getInputsToOp().get(0); + } else { + break; + } + } else { + break; + } + } else { + break; + } + } + + return new ExecStep(ExecType.OP, outOfOp, fi); + } + + //Intentional fall-through to default case + } + return new ExecStep(ExecType.OP, outOfOp, frameIter); + } + } + + /** + * Initialize the subgraph - the subgraph and subgraphOps sets + * This works our what ops and variables we might need to execute to get the requested outputs. + * In general, this is a subset of the graph. + * + * @param variables Set of output variables we need + */ + protected void initSubgraph(Set variables) { //Step 1: determine subgraph structure we actually need to execute Queue processingQueue = new LinkedList<>(variables); @@ -434,21 +762,20 @@ public abstract class AbstractSession { // until after execution of some other ops (for example, in conditional operations) numInputs += controlDeps.size(); } - if (numInputs == 0) { - VarId vid = newVarId(varName, OUTER_FRAME, 0, null); - if(!availableForExecSet.contains(vid)) { - availableForExec.add(vid); - availableForExecSet.add(vid); - } - execInputs.put(vid, new HashSet()); + if (numInputs == 0 && opName != null) { + zeroInputOpsInSubgraph.add(opName); } subgraph.add(varName); - if(controlDeps != null){ + if (opName != null) { + subgraphOps.add(opName); + } + + if (controlDeps != null) { //If variable has control dependencies, it's not available right away... to make it available, // we need the "inputs" to be available first. This is mainly used for TF import. - for(String s : controlDeps){ - if(!subgraph.contains(s)){ + for (String s : controlDeps) { + if (!subgraph.contains(s)) { processingQueue.add(s); } } @@ -477,359 +804,28 @@ public abstract class AbstractSession { } } - /** - * This method should be called for a variable once it's array is ready for use. - * For example, post op execution, etc - * - * @param execStep Current execution step (mainly for debugging) - * @param executedVar Variable that was just executed - */ - protected void updateDescendentsForExec(int execStep, VarId executedVar) { - String varName = executedVar.getVariable(); - Variable var = sameDiff.getVariables().get(executedVar.getVariable()); - //Find any ops (or variables with control dependencies) that this is required for execution of and check if now available for exec - List l = sameDiff.getVariables().get(executedVar.getVariable()).getInputsForOp(); - String[] inputForOps = l == null ? null : l.toArray(new String[l.size()]); //Just executed variable is input to these ops - List controlDepForVars = var.getControlDepsForVar(); //Just executed variable is a control dependency for these variables - List controlDepForOps = var.getControlDepsForOp(); //Just executed variable is a control dependency for these ops - - - SDVariable v = var.getVariable(); - boolean isConstOrPhInput = v.isPlaceHolder() || v.isConstant(); - - //After a variable becomes available, we should look at the ops this is an input to, and check if we can execute this op now... - if (inputForOps != null) { - for (String opName : inputForOps) { - - DifferentialFunction fn = sameDiff.getOpById(opName); - if (fn instanceof Merge) { - //Merge op: available for execution when *any* of its inputs are available. But only mark it for exec once... - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - Preconditions.checkState(opOutputs.size() == 1, "Expected only 1 output variable for merge op, got %s", opOutputs); - VarId outVarId = newVarId(opOutputs.get(0), executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if (!nodeOutputs.containsKey(outVarId) && subgraph.contains(outVarId.getVariable()) && !availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked merge op ({}) variable {} as available for execution: input {} is now available", opName, outVarId, executedVar); - } - - //Mark that we need the specified input to calculate this output - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - continue; - } else if (fn instanceof Enter) { - //Enter node: available for exec when any of its inputs are available for exec - // Note input feeds from one frame to another - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - Preconditions.checkState(opOutputs.size() == 1, "Expected only 1 output variable for enter op, got %s", opOutputs); - Enter e = (Enter) fn; - boolean isConstant = e.isConstant(); - VarId outVarId = newVarId(opOutputs.get(0), e.getFrameName(), 0, executedVar.toFrameIter()); //Note: parent frame of output op is enter var's *current* frame - - if(isConstant && executedVar.getParentFrame() != null){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - outVarId.setParentFrame(outVarId.getParentFrame().clone()); - FrameIter fi = outVarId.getParentFrame(); - while(fi != null){ - fi.setIteration(0); - fi = fi.getParentFrame(); - } - } - - if (!nodeOutputs.containsKey(outVarId) && subgraph.contains(outVarId.getVariable()) && !availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked enter op ({}) variable {} as available for execution: input {} is now available", opName, outVarId, executedVar); - } - - //Also record the parent frame: we'll need this when we get to the corresponding exit ops - frameParents.put(e.getFrameName(), executedVar.toFrameIter()); - - //Mark that we need the specified input to calculate this output - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - continue; - } else if (fn instanceof Exit) { - //Exit node forwards input to parent frame - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - FrameIter parentFrame = frameParents.get(executedVar.getFrame()); - Preconditions.checkNotNull(parentFrame, "Parent frame must not be null for exit op: variable to exec is %s", executedVar); - - VarId outVarId = new VarId(opOutputs.get(0), parentFrame.getFrame(), parentFrame.getIteration(), executedVar.getParentFrame().getParentFrame()); //Parent frame of output is parent of current parent - if (!nodeOutputs.containsKey(outVarId) && subgraph.contains(outVarId.getVariable()) && !availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked Exit op ({}) variable {} as available for execution: input {} is now available", opName, outVarId, executedVar); - } - - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - continue; - } else if (fn instanceof NextIteration) { - //NextIteration is available for execution when its single input is available - //NextIteration op: forwards its single input to the output of the current frame, but increments the iteration number - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - Preconditions.checkState(opOutputs.size() == 1, "Expected exactly 1 output for NextIteration op: got %s", opOutputs); - VarId outVarId = newVarId(opOutputs.get(0), executedVar.getFrame(), executedVar.getIteration() + 1, executedVar.getParentFrame()); - - if (!nodeOutputs.containsKey(outVarId) && subgraph.contains(outVarId.getVariable()) && !availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked NextIteration op ({}) variable {} as available for execution: input {} is now available", opName, outVarId, executedVar); - } - - //Mark that we need the specified input to calculate this output - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - continue; - } - //Note for LoopCond: just forwards input to output - so basically handle it the same as other ops here - - - //Can execute this op - and hence get it's output variables - if all inputs (and control deps) are available - String[] inputsThisOp = fn.argNames(); - boolean allInputsAvailable = true; - if (inputsThisOp != null) { - allInputsAvailable = allInputsAvailable(execStep, inputsThisOp, executedVar); - } - - //Check Op control dependencies - List opControlDeps = sameDiff.getOps().get(opName).getControlDeps(); - if (opControlDeps != null && allInputsAvailable) { - for (String cd : opControlDeps) { - VarId vcd = newVarId(cd, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if (!nodeOutputs.containsKey(vcd)) { - allInputsAvailable = false; - break; - } - } - } - - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - if (opOutputs != null) { - - for (String s : opOutputs) { - //The input (for normal ops - not Enter/Exit/NextIteration) have the same frame and iteration number as the just executed var - //Exception 1 to this: constants. If variable is a constant, then it's always iteration 0 of the main frame (unless variable control dep exists) - //Exception 2 to this: placeholders. As above - SDVariable sdv = sameDiff.getVariable(s); - Variable variable = sameDiff.getVariables().get(s); - VarId outVarId; - if (sdv.isConstant() || sdv.isPlaceHolder()) { - //Constant - if(variable.getControlDeps() == null || var.getControlDeps().isEmpty()){ - //Standard case - do a lookup of placeholder/constant - outVarId = newVarId(s, OUTER_FRAME, 0, null); - } else { - //Edge case: control dependency x -> constant exists - //We should look up based on x's frame/iteration - outVarId = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - } else { - //Normal (non-constant) - outVarId = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - - //Mark that we need the specified input to calculate this output - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - - //Check variable control dependencies, for each of the op outputs - if(allInputsAvailable && variable.getControlDeps() != null && !variable.getControlDeps().isEmpty()){ - //If one of the op outputs has a control dependency input, make sure this is available - // before executing the op - //For example, if z=add(x,y) and control dependency A->z exists, then don't execute op until A is available - for(String cd : variable.getControlDeps()){ - Variable cdVar = sameDiff.getVariables().get(cd); - VarId cdVarId = null; - if (cdVar.getVariable().isConstant() || cdVar.getVariable().isPlaceHolder()) { - //Constant - if(variable.getControlDeps() == null || var.getControlDeps().isEmpty()){ - //Standard case - do a lookup of placeholder/constant - cdVarId = newVarId(cd, OUTER_FRAME, 0, null); - } else { - //Edge case: control dependency x -> constant -> thisOutput exists - //We should look up based on x's frame/iteration - cdVarId = newVarId(cd, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - } else { - //Normal (non-constant) - cdVarId = newVarId(cd, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - allInputsAvailable &= nodeOutputs.containsKey(cdVarId); - if(!allInputsAvailable) - break; - } - } - } - - if (allInputsAvailable) { - //Op can be executed -> variables as output are available for exec - - for (String s : opOutputs) { - if (!subgraph.contains(s)) - continue; //Don't need this variable to calculate requested outputs - so don't mark as available for execution - VarId vid = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if(!availableForExecSet.contains(vid)) { - availableForExec.add(vid); - availableForExecSet.add(vid); - log.trace("Marked variable as available for execution: {} - output of op {} ({}) with op inputs {}", vid, opName, - fn.getClass().getSimpleName(), (inputsThisOp == null ? "" : Arrays.toString(inputsThisOp))); - } - } - } - } - - } - } - - //Also check variable control dependencies... if control dependency varX->varY exists and varY is a constant/placeholder/variable, - // then it's not going to be triggered by the op-based check above - if(controlDepForVars != null){ - for(String s : controlDepForVars){ - if (!subgraph.contains(s)) - continue; //Don't need this variable to calculate requested outputs - so don't mark as available for execution - - SDVariable depFor = sameDiff.getVariable(s); - if(depFor.getVariableType() != VariableType.ARRAY){ - //Control dependency executedVar -> s exists, where "s" is not the output of an op - //Even thought this is a constant, we'll inherit the frame and iteration from the control dependency - // otherwise, we lose this frame/iteration information for any downstream variables using the constant within a frame - VarId outVarId = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if(!availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked variable as available for execution: {} - control dependency {} -> {} exists", outVarId, executedVar.getVariable(), s); - } - } else { - //Another edge case: OpX has output varY (with no inputs), and control dependency executedVar -> varY exists - //We should check if OpX is now available for execution... - //Similarly, if we have OpX with inputs, but we're only waiting on a varible control dependency Z -> X - // then we might not get triggered as available for exec above either - String opName = sameDiff.getVariables().get(s).getOutputOfOp(); - if(opName != null){ - SameDiffOp op = sameDiff.getOps().get(opName); - boolean allInputsAvailable = true; - if(op.getInputsToOp() != null && !op.getInputsToOp().isEmpty()){ - List inputList = op.getInputsToOp(); - allInputsAvailable = allInputsAvailable(execStep, inputList.toArray(new String[inputList.size()]), executedVar); - } - - if(allInputsAvailable && op.getControlDeps() != null){ - for(String cd : op.getControlDeps()){ - VarId vid = newVarId(cd, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); //Note: is array type, therefore has same frame/iter as parent - allInputsAvailable &= nodeOutputs.containsKey(vid); - if(!allInputsAvailable) - break; - } - } - if(allInputsAvailable){ - for(String opOutput : op.getOutputsOfOp()){ - Variable v2 = sameDiff.getVariables().get(opOutput); - if(v2.getControlDeps() != null){ - for(String s2 : v2.getControlDeps()){ - VarId vid = newVarId(s2, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); //Note: is array type, therefore has same frame/iter as parent - allInputsAvailable &= nodeOutputs.containsKey(vid); - if(!allInputsAvailable) - break; - } - } - } - } - - if(allInputsAvailable){ - VarId outVarId = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if(!availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - log.trace("Marked variable as available for execution: {} - is output of op {} with no inputs (but has control dependencies)", outVarId, op.getName()); - } - } - } - } - } - } - - //Edge case: if control dependency varX->opY exists, and opY doesn't have any inputs, it also can't be triggeered - // (made available for execution) by any of the previous checks. For any ops that DO have inputs, they will - // be triggered already - if(controlDepForOps != null){ - for(String opName : controlDepForOps){ - SameDiffOp op = sameDiff.getOps().get(opName); - if(op.getInputsToOp() == null || op.getInputsToOp().isEmpty()){ - for(String out : op.getOutputsOfOp()){ - if (!subgraph.contains(out)) - continue; //Don't need this variable to calculate requested outputs - so don't mark as available for execution - - //TODO is it possible to have both variable and op control dependencies?? - VarId outVarId = newVarId(out, OUTER_FRAME, 0, null); - if(!availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked variable as available for execution: {} - op control dependency variable {} -> op {} exists", outVarId, executedVar.getVariable(), opName); - } - } - } - } - } - } - - protected boolean allInputsAvailable(int execStep, String[] inputsThisOp, VarId executedVar){ - for (String in : inputsThisOp) { - //The input (for normal ops - not Enter/Exit/NextIteration) have the same frame and iteration number as the just executed var - //Exception 1 to this: constants. If variable is a constant, then it's always iteration 0 of the main frame (unless variable control dep exists) - //Exception 2 to this: placeholders. As above - //TODO Add SameDiff.isConstant(String) method... or SDVariable.isConstant() (or both) - SDVariable sdv = sameDiff.getVariable(in); - Variable variable = sameDiff.getVariables().get(in); - VarId vid; - boolean nestedWhile = false; - if (sdv.isConstant() || sdv.isPlaceHolder()) { - //Constant - if(variable.getControlDeps() == null || variable.getControlDeps().isEmpty()){ - //Standard case - do a lookup of placeholder/constant - vid = newVarId(in, OUTER_FRAME, 0, null); - } else { - //Edge case: control dependency x -> constant exists - //We should look up based on x's frame/iteration - vid = newVarId(in, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - } else { - //Normal (non-constant) - //Edge case: "Enter" nodes always have iteration 0 by definition. In some TF graphs/loops, the enter node - // is used in multiple iterations (like, a constant in a loop condition) - not just the first iteration - int iter = executedVar.getIteration(); - FrameIter parentFrame = executedVar.getParentFrame(); - if(sdv.getVariableType() == VariableType.ARRAY && sameDiff.getOps().get(variable.getOutputOfOp()).getOp() instanceof Enter){ - iter = 0; - Enter e = (Enter)sameDiff.getOps().get(variable.getOutputOfOp()).getOp(); - if(e.isConstant()){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - parentFrame = parentFrame.clone(); - FrameIter toZero = parentFrame; - while(toZero != null){ - toZero.setIteration(0); - toZero = toZero.getParentFrame(); - } - } - } - vid = newVarId(in, executedVar.getFrame(), iter, parentFrame); - } - if (!nodeOutputs.containsKey(vid)) { - return false; - } - } - return true; - } - /** * Preprocess the placeholder values, if required. * Mainly reserved for casting in the case of InferenceSession + * * @param placeholders Placeholders to preprocess. * @return Preprocessed placeholders */ - protected Map preprocessPlaceholders(Map placeholders){ + protected Map preprocessPlaceholders(Map placeholders, At at) { return placeholders; } + /** + * Post process the session output values, if required. + * Override if required in session subclasses + * + * @param output Output to be returned to the user + * @return Post processed output + */ + protected Map postProcessOutput(Map output) { + return output; + } + /** * Get the constant or variable output - for example, constant array or constant shape. * Note that both constants and variables (i.e., VariableType.CONSTANT and VariableType.VARIABLE) are the same @@ -848,9 +844,11 @@ public abstract class AbstractSession { * @param inputs The inputs to the op (excluding constants/placeholders) - for the specific frame + iteration * @param allIterInputs The inputs - those that are not iteration-specific (mainly Enter op vars, which might be used in all iterations but are only executed once on iter 0) * @param constAndPhInputs The constant and placeholder inputs - used for all frames/iterations + * @param allReqVariables All required variables requested for the current session execution (not just the current op outputs) * @return The parameterized op */ - public abstract O getAndParameterizeOp(String opName, FrameIter frameIter, Set inputs, Set allIterInputs, Set constAndPhInputs, Map placeholderValues); + public abstract O getAndParameterizeOp(String opName, FrameIter frameIter, Set inputs, Set allIterInputs, Set constAndPhInputs, + Map placeholderValues, Set allReqVariables); /** * Execute the op - calculate INDArrays, or shape info, etc @@ -858,88 +856,49 @@ public abstract class AbstractSession { * @param op Operation to exit. This should be parameterized (i.e., all inputs set) * @param outputFrameIter The frame and iteration of the outputs * @param inputs The specific input arrays for the op + * @param allReqVariables All required variables requested for the current session execution (not just the current op outputs) * @return The outputs of the op */ public abstract T[] getOutputs(O op, FrameIter outputFrameIter, Set inputs, Set allIterInputs, Set constAndPhInputs, - List listeners, At at, MultiDataSet batch); + List listeners, At at, MultiDataSet batch, Set allReqVariables); /** - * This method is used to record that the specified input is required for calculating the specified output. - * While the graph structure itself provides us with the (input vars) -> op -> (output vars) type structure, it - * doesn't tell us exactly which array copy (i.e., variable + frame + iteration) to use as which copy of the output - * variable (variable + frame + iteration). - *

- * This method is basically used to store information we need to parameterize ops for execution later - * - * @param isConstOrPh If true: inputVar is either a constant or a placeholder - * @param inputVar Input variable (i.e., the X in (X, ...) -> op -> (forVariable,...)) - * @param forVariable Output variable (i.e., the Y in (inputVar, ...) -> op -> (Y,...)) + * Get the VarId from the specified name. The VarId should be in one or the other of the collections, + * and only one VarId with that name should exist */ - protected void addToExecInputs(boolean isConstOrPh, VarId inputVar, VarId forVariable) { - if (!subgraph.contains(forVariable.getVariable())) - return; //Not needed to calculate requested outputs, so no need to record it's inputs + protected static VarId lookup(String name, Collection varIds, Collection varIds2, boolean exceptionOnNotFound) { + VarId vid = varIds == null ? null : lookup(name, varIds, false); + if (vid == null && varIds2 != null) + vid = lookup(name, varIds2, false); - if (isConstOrPh) { - //Mark that outVar needs to use placeholder/constant (same regardless of frame/iter) - if (!execConstInputs.containsKey(forVariable.getVariable())) - execConstInputs.put(forVariable.getVariable(), new HashSet()); - execConstInputs.get(forVariable.getVariable()).add(inputVar.getVariable()); - } else { - //Mark that outVar needs this specific executedVar (i.e., specific frame/iteration) - //However, in the case of enter nodes, they are available for ALL iterations (used in loop conditions, for example) - Variable v = sameDiff.getVariables().get(inputVar.getVariable()); - boolean isEnter = sameDiff.getVariableOutputOp(v.getVariable().getVarName()) instanceof Enter; - - if(isEnter){ - VarId iter0 = forVariable; - if(iter0.getIteration() != 0){ - iter0 = newVarId(iter0.getVariable(), iter0.getFrame(), 0, forVariable.getParentFrame()); - } - - Variable var = sameDiff.getVariables().get(inputVar.getVariable()); - Enter e = (Enter) sameDiff.getOps().get(var.getOutputOfOp()).getOp(); - if(e.isConstant()){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - iter0.setParentFrame(iter0.getParentFrame().clone()); - FrameIter toZero = iter0.getParentFrame(); - while(toZero != null){ - toZero.setIteration(0); - toZero = toZero.getParentFrame(); - } - } - - if(!execInputsAllIter.containsKey(iter0)) - execInputsAllIter.put(iter0, new HashSet()); - execInputsAllIter.get(iter0).add(inputVar); - } else { - //Most variables - if (!execInputs.containsKey(forVariable)) - execInputs.put(forVariable, new HashSet()); - execInputs.get(forVariable).add(inputVar); - } + if (vid == null && exceptionOnNotFound) { + throw new RuntimeException("Could not find VarId for input \"" + name + "\""); } + return vid; } - - protected static VarId lookup(String name, Collection varIds, boolean exceptionOnNotFound){ - for(VarId vid : varIds){ - if(vid.getVariable().equals(name)){ + /** + * Get the VarId from the specified name. The VarId should be in the collection, + * and only one VarId with that name should exist + */ + protected static VarId lookup(String name, Collection varIds, boolean exceptionOnNotFound) { + for (VarId vid : varIds) { + if (vid.getVariable().equals(name)) { return vid; } } - if(exceptionOnNotFound) { + if (exceptionOnNotFound) { throw new RuntimeException("Could not find VarId to input " + name); } return null; } - /* - VarId: identifies a variable in a specific frame and frame iteration - Used for 2 places: - (a) to identify variables that are available for execution - (b) to store results + /** + * VarId: identifies the value of a variable in a specific frame and frame iteration
+ * Note that frames can be nested - which generally represents nested loop situations.
+ * Used for 2 places:
+ * (a) to identify variables that are available for execution
+ * (b) to store results
*/ @Data @AllArgsConstructor @@ -954,13 +913,17 @@ public abstract class AbstractSession { return "VarId(\"" + variable + "\",\"" + frame + "\"," + iteration + ",parent=" + parentFrame + ")"; } + /** + * @return FrameIter corresponding to the VarId + */ public FrameIter toFrameIter() { return new FrameIter(frame, iteration, parentFrame); } } - /* - FrameIter: Identifies frame + iteration. Used mainly for for exit nodes + /** + * FrameIter: Identifies a frame + iteration (but not a specific op or variable).
+ * Note that frames can be nested - which generally represents nested loop situations. */ @Data @AllArgsConstructor @@ -970,13 +933,82 @@ public abstract class AbstractSession { private FrameIter parentFrame; @Override - public String toString(){ + public String toString() { return "(\"" + frame + "\"," + iteration + (parentFrame == null ? "" : ",parent=" + parentFrame.toString()) + ")"; } @Override - public FrameIter clone(){ + public FrameIter clone() { return new FrameIter(frame, iteration, (parentFrame == null ? null : parentFrame.clone())); } + + public VarId toVarId(String name) { + return new VarId(name, frame, iteration, parentFrame); + } } + + /** + * ExecType: Execution type, as used in ExecStep
+ * OP: Operation execution
+ * VARIABLE: Variable "execution", mainly used to trigger ops that depend on the variable
+ * CONSTANT: As per variable
+ * PLACEHOLDER: As per variable
+ * SWITCH_L and SWITCH_R: This is a bit of a hack to account for the fact that only one of + * the switch branches (left or right) will ever be available; without this, once the switch op is executed, we'll + * (incorrectly) conclude that *both* branches can be executed
+ * EXEC_START: Start of execution
+ * CONTROL_DEP: Control dependency for op. Used for TF import, due to its odd "constant depends on op in a frame" behaviour + */ + protected enum ExecType {OP, VARIABLE, CONSTANT, PLACEHOLDER, SWITCH_L, SWITCH_R, EXEC_START, CONTROL_DEP} + + ; + + /** + * ExecStep represents a single execution step, for a single op (or variable/constant etc) at a specific frame/iteration + */ + @Getter + @EqualsAndHashCode + protected static class ExecStep { + protected final ExecType type; + protected final String name; + protected final FrameIter frameIter; + + protected ExecStep(@NonNull ExecType execType, @NonNull String name, FrameIter frameIter) { + this.type = execType; + this.name = name; + this.frameIter = frameIter; + } + + protected VarId toVarId() { + return new VarId(name, frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame()); + } + + @Override + public String toString() { + return "ExecStep(" + type + ",name=\"" + name + "\"," + frameIter + ")"; + } + } + + /** + * Used in getting the next ExecStep that matches the specified (current) frame/iteration + */ + @Data + @AllArgsConstructor + @NoArgsConstructor + protected class ExecStepPredicate implements Predicate { + + protected String currentFrame; + protected int currentFrameIter; + protected FrameIter currParentFrame; + + @Override + public boolean test(ExecStep execStep) { + return currentFrame.equals(execStep.getFrameIter().getFrame()) && + currentFrameIter == execStep.getFrameIter().getIteration() && + (currParentFrame == null && execStep.getFrameIter().getParentFrame() == null || + currParentFrame.equals(execStep.getFrameIter().getParentFrame())); + } + } + + ; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java deleted file mode 100644 index 56a6a406e..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java +++ /dev/null @@ -1,107 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.autodiff.samediff.internal; - -import lombok.AllArgsConstructor; -import lombok.Data; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.listeners.At; -import org.nd4j.autodiff.listeners.Listener; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.buffer.DataType; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import org.nd4j.linalg.dataset.api.MultiDataSet; - -/** - * Infer datatypes for all variables. - * Optionally update the datatypes of variables as we go - */ -public class DataTypesSession extends AbstractSession { - - protected boolean dynamicUpdate; - - /** - * @param sameDiff SameDiff instance - * @param dynamicUpdate If true: Dynamically update the datatypes as we go - */ - public DataTypesSession(SameDiff sameDiff, boolean dynamicUpdate) { - super(sameDiff); - this.dynamicUpdate = dynamicUpdate; - } - - @Override - public DataType getConstantOrVariable(String variableName) { - //Variables and constants should always have datatype available - DataType dt = sameDiff.getVariable(variableName).dataType(); - Preconditions.checkNotNull(dt, "No datatype available for variable %s", variableName); - return dt; - } - - @Override - public DataTypeCalc getAndParameterizeOp(String opName, FrameIter frameIter, Set inputs, Set allIterInputs, Set constAndPhInputs, Map placeholderValues) { - DifferentialFunction df = sameDiff.getOpById(opName); - List inputDataTypes = new ArrayList<>(); - for(SDVariable v : df.args()){ - DataType dt = v.dataType(); - if(dt != null){ - inputDataTypes.add(dt); - } else { - String s = v.getVarName(); - for(VarId vid : inputs){ - if(vid.getVariable().equals(s)){ - DataType dt2 = nodeOutputs.get(vid); - Preconditions.checkNotNull(dt2, "No datatype for %s", vid); - inputDataTypes.add(dt2); - } - } - } - } - return new DataTypeCalc(df, inputDataTypes); - } - - @Override - public DataType[] getOutputs(DataTypeCalc op, FrameIter outputFrameIter, Set inputs, Set allIterInputs, - Set constAndPhInputs, List listeners, At at, MultiDataSet batch) { - List outTypes = op.getFn().calculateOutputDataTypes(op.getInputTypes()); - - if(dynamicUpdate) { - SDVariable[] fnOutputs = op.getFn().outputVariables(); - for( int i=0; i inputTypes; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyList.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyList.java new file mode 100644 index 000000000..c718bf152 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyList.java @@ -0,0 +1,20 @@ +package org.nd4j.autodiff.samediff.internal; + +import lombok.AllArgsConstructor; +import lombok.Data; +import org.nd4j.linalg.primitives.Pair; + +import java.util.List; + +/** + * A list of dependencies, used in {@link AbstractDependencyTracker} + * + * @author Alex Black + */ +@Data +@AllArgsConstructor +public class DependencyList { + private T dependencyFor; + private List dependencies; + private List> orDependencies; +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyTracker.java new file mode 100644 index 000000000..d172221ee --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyTracker.java @@ -0,0 +1,38 @@ +package org.nd4j.autodiff.samediff.internal; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.primitives.Pair; + +import java.util.*; + +/** + * Dependenci tracker. See {@link AbstractDependencyTracker} for details + * + * @param For a dependency X -> Y, Y has type T + * @param For a dependency X -> Y, X has type D + */ +@Slf4j +public class DependencyTracker extends AbstractDependencyTracker { + + @Override + protected Map newTMap() { + return new HashMap<>(); + } + + @Override + protected Set newTSet() { + return new HashSet<>(); + } + + @Override + protected String toStringT(T t) { + return t.toString(); + } + + @Override + protected String toStringD(D d) { + return d.toString(); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/IdentityDependencyTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/IdentityDependencyTracker.java new file mode 100644 index 000000000..5e7e46c80 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/IdentityDependencyTracker.java @@ -0,0 +1,44 @@ +package org.nd4j.autodiff.samediff.internal; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; + +import java.util.*; + +/** + * Object dependency tracker, using object identity (not object equality) for the Ys (of type T)
+ * See {@link AbstractDependencyTracker} for more details + * + * @author Alex Black + */ +@Slf4j +public class IdentityDependencyTracker extends AbstractDependencyTracker { + + @Override + protected Map newTMap() { + return new IdentityHashMap<>(); + } + + @Override + protected Set newTSet() { + return Collections.newSetFromMap(new IdentityHashMap()); + } + + @Override + protected String toStringT(T t) { + if(t instanceof INDArray){ + INDArray i = (INDArray)t; + return System.identityHashCode(t) + " - id=" + i.getId() + ", " + i.shapeInfoToString(); + } else { + return System.identityHashCode(t) + " - " + t.toString(); + } + } + + @Override + protected String toStringD(D d) { + return d.toString(); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index e16dad580..354f537a8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -16,7 +16,7 @@ package org.nd4j.autodiff.samediff.internal; -import lombok.NonNull; +import lombok.*; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; @@ -24,15 +24,17 @@ import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.*; import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner; -import org.nd4j.linalg.api.ops.impl.controlflow.If; -import org.nd4j.linalg.api.ops.impl.controlflow.While; import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; +import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; +import org.nd4j.linalg.api.ops.impl.shape.Concat; +import org.nd4j.linalg.api.ops.impl.shape.Stack; import org.nd4j.linalg.api.ops.impl.shape.tensorops.*; import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; @@ -48,36 +50,92 @@ import org.nd4j.linalg.util.ArrayUtil; import java.util.*; /** - * InferenceSession: Performs inference (forward pass) on a SameDiff instance to get the outputs of the requested nodes. - * Dynamically (in AbstractSession) calculates the required subgraph to execute to get the required outputs. + * InferenceSession: Performs inference (forward pass) on a SameDiff instance to get the outputs of the requested nodes.
+ * Dynamically (in AbstractSession) calculates the required subgraph to execute to get the required outputs.
+ * Note that while AbstractSession handles the graph structure component, InferenceSession handles only op execution + * and memory management
+ *
+ * For INDArray memory management - i.e., tracking and releasing memory manually, as soon as possible, to + * minimize memory use - this is implemented using a {@link SessionMemMgr} instance (for allocations/deallocations) and + * also {@link IdentityDependencyTracker} to track where arrays are actually used. The IdentityDependencyTracker tells + * us when the array is no longer needed (i.e., has been "fully consumed" by all ops depending on it) accounting for the + * fact that some operations, such as identity, enter, exit, etc, are "zero copy" for performance reasons. * * @author Alex Black */ @Slf4j -public class InferenceSession extends AbstractSession { +public class InferenceSession extends AbstractSession { private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\n" + "Alternatively, arrays defined in a workspace must be replaced after the workspace has been closed."; + protected static final String KERAS_TRAIN_TEST = "keras_learning_phase"; + + @Getter + @Setter + private SessionMemMgr mmgr; //Used for allocating and deallocating memory + /** + * Array use tracker: What needs to happen before the array can be closed/released? + * As the name suggests, the INDArrays are tracked using qbject identity, not equality + */ + @Getter + @Setter + private IdentityDependencyTracker arrayUseTracker = new IdentityDependencyTracker<>(); + + public InferenceSession(@NonNull SameDiff sameDiff) { super(sameDiff); + + mmgr = new ArrayCloseMemoryMgr(); //TODO replace this with new (planned) array reuse memory manager } @Override - protected Map preprocessPlaceholders(Map placeholders){ - //Handle casting of the input array automatically. - //The idea here is to avoid unexpected errors if the user (for example) tries to perform inference with a double - // array for a float placeholder - if(placeholders == null || placeholders.isEmpty()){ + protected Map preprocessPlaceholders(Map placeholders, At at) { + arrayUseTracker.clear(); + + //We'll also use this method as a "pre execution" hook-in, to mark variables as something we should never deallocate + //This occurs by never marking these "ConstantDep" and "VariableDep" instances as satisfied, so there's always + // an unsatisfied dependency for them in the array use tracker + //TODO we shouldn't be clearing this on every single iteration, in 99.5% of cases variables will be same as last iteration... + for (SDVariable v : sameDiff.variables()) { + if (v.getVariableType() == VariableType.CONSTANT) { + arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.getVarName())); + } else if (v.getVariableType() == VariableType.VARIABLE) { + arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.getVarName())); + } + } + + //Workaround for some TF/Keras based models that require explicit train/test as a placeholder + boolean kerasWorkaround = false; + List phs = sameDiff.inputs(); + if (phs != null && !phs.isEmpty()) { + for (String s : phs) { + if (s.endsWith(KERAS_TRAIN_TEST) && !placeholders.containsKey(s)) { + // The behaviour of some Keras layers (like GRU) differs depending on whether the model is training. + // We provide this value directly, unless the user has provided this manually + INDArray scalar = mmgr.allocate(false, DataType.BOOL).assign(at.operation().isTrainingPhase()); + placeholders = new HashMap<>(placeholders); //Array might be singleton, or otherwise unmodifiable + placeholders.put(s, scalar); + kerasWorkaround = true; + } + } + } + + + if (placeholders == null || placeholders.isEmpty()) { return placeholders; } - Map out = new HashMap<>(); - for(Map.Entry e : placeholders.entrySet()){ + //Handle casting of the input array automatically. + //The idea here is to avoid unexpected errors if the user (for example) tries to perform inference with a double + // array for a float placeholder + //TODO eventually we might have ops that support multiple input types, and hence won't need this casting + Map out = new HashMap<>(); + for (Map.Entry e : placeholders.entrySet()) { Preconditions.checkState(sameDiff.hasVariable(e.getKey()), "Invalid placeholder passed for execution: " + "No variable/placeholder with name %s exists", e.getKey()); INDArray arr = e.getValue(); //First: check workspaces - if(arr.isAttached()){ + if (arr.isAttached()) { MemoryWorkspace ws = arr.data() == null ? null : arr.data().getParentWorkspace(); if (ws != null && ws.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) { if (!ws.isScopeActive()) { @@ -96,89 +154,234 @@ public class InferenceSession extends AbstractSession opInputs, Set allIterInputs, - Set constAndPhInputs, List listeners, At at, MultiDataSet batch) { - if(listeners != null && listeners.size() > 0){ - SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName()); - for(Listener l : listeners){ - if(l.isActive(at.operation())) + protected Map postProcessOutput(Map output) { + + //For any queued (not yet processed) ops - mark them as satisfied, so we can deallocate any arrays + // that are waiting on them + if (dt.hasNewAllSatisfied()) { + List execSteps = dt.getNewAllSatisfiedList(); + for (ExecStep es : execSteps) { + if (es.getType() == ExecType.OP) { + OpDep od = new OpDep(es.getName(), es.getFrameIter().getFrame(), es.getFrameIter().getIteration(), es.getFrameIter().getParentFrame()); + arrayUseTracker.markSatisfied(od, true); + } + } + } + + //Also mark "end of execution" for array dependency tracker. Mainly used for TensorArray arrays at present. + //TODO Optimize for reduced memory for some TensorArray operations - i.e., close/deallocate earlier + arrayUseTracker.markSatisfied(new ExecDoneDep(), true); + if (arrayUseTracker.hasNewAllSatisfied()) { + List l = arrayUseTracker.getNewAllSatisfiedList(); + for (INDArray arr : l) { + mmgr.release(arr); + } + } + + return output; + } + + @Override + public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + Set constAndPhInputs, List listeners, At at, MultiDataSet batch, Set allReqVariables) { + if (listeners != null && listeners.size() > 0) { + SameDiffOp sdOp = sameDiff.getOps().get(op.getOp().getOwnName()); + for (Listener l : listeners) { + if (l.isActive(at.operation())) l.preOpExecution(sameDiff, at, sdOp); } } - INDArray[] out = getOutputsHelper(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs); - if(listeners != null && listeners.size() > 0){ - SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName()); + INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs); + op.getOp().clearArrays(); - Map namedOutsBuilder = new HashMap<>(); + if (log.isTraceEnabled()) { + StringBuilder sb = new StringBuilder(); + sb.append(op.getName()).append(" - ").append(outputFrameIter).append(" outputs: "); + List opOutNames = op.getOutputsOfOp(); + for (int i = 0; i < out.length; i++) { + if (i > 0) + sb.append(", "); + sb.append("(").append(i).append(" - ").append(opOutNames.get(i)).append(" = ").append( + out[i] == null ? null : out[i].getId()).append(")"); + } + log.trace(sb.toString()); + } - for(int i = 0 ; i < out.length ; i++) - namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]); + //Call listeners, before we (maybe) deallocate input arrays + if (listeners != null && listeners.size() > 0) { + Map namedOuts = null; - Map namedOuts = Collections.unmodifiableMap(namedOutsBuilder); + for (Listener l : listeners) { + if (l.isActive(at.operation())) { + //Lazily create map, only if required + if (namedOuts == null) { + Map namedOutsBuilder = new HashMap<>(); - for(Listener l : listeners){ - if(l.isActive(at.operation())) { - l.opExecution(sameDiff, at, batch, sdOp, out); + for (int i = 0; i < out.length; i++) + namedOutsBuilder.put(op.outputsOfOp.get(i), out[i]); + namedOuts = Collections.unmodifiableMap(namedOutsBuilder); + } - for(String varName : namedOuts.keySet()){ - l.activationAvailable(sameDiff, at, batch, sdOp, varName, namedOuts.get(varName)); + + l.opExecution(sameDiff, at, batch, op, out); + + for (String varName : namedOuts.keySet()) { + l.activationAvailable(sameDiff, at, batch, op, varName, namedOuts.get(varName)); } } } } + + + //Record array uses for memory management/deallocation + SameDiffOp o = sameDiff.getOps().get(op.getName()); + List outVarNames = o.getOutputsOfOp(); + for (int i = 0; i < out.length; i++) { + if (out[i] == null && o.getOp() instanceof Switch) + continue; //Switch case: we only ever get one of 2 outputs, other is null (branch not executed) + + String name = outVarNames.get(i); + Variable v = sameDiff.getVariables().get(name); + List inputsForOps = v.getInputsForOp(); + if (inputsForOps != null) { + for (String opName : inputsForOps) { + //Only add dependencies if we actually need the op this feeds into, otherwise the dependency + // will will never be marked as satisfied + if (!subgraphOps.contains(opName)) + continue; + + SameDiffOp forOp = sameDiff.getOps().get(opName); + + //TODO do switch or merge need special handling also? + if (forOp.getOp() instanceof Enter) { + Enter e = (Enter) forOp.getOp(); + if (e.isConstant()) { + /* + Contant enter case: Need to keep this array around for the entire duration of the frame, including + any nested frames, and all iterations. + Unfortunately, we don't know exactly when we're done with a frame for good + This isn't a great solution, but other possibilities (frame close, trying to detect all exit ops, + detecting return to parent frame, etc all fail in certain circumstances, such as due to control dependencies + on variables). + */ + Dep d = new ExecDoneDep(); + arrayUseTracker.addDependency(out[i], d); + } else { + Dep d = new OpDep(opName, e.getFrameName(), 0, outputFrameIter); + arrayUseTracker.addDependency(out[i], d); //Op defined by "d" needs to be executed before specified array can be closed + } + } else if (forOp.getOp() instanceof NextIteration) { + //The array is needed by the NEXT iteration op, not the current one + Dep d = new OpDep(opName, outputFrameIter.getFrame(), outputFrameIter.getIteration() + 1, outputFrameIter.getParentFrame()); + arrayUseTracker.addDependency(out[i], d); + } else if (forOp.getOp() instanceof Exit) { + //The array is needed at the EXIT frame (i.e., parent frame), not the inner/just executed one + FrameIter fi = outputFrameIter.getParentFrame(); + Dep d = new OpDep(opName, fi.getFrame(), fi.getIteration(), fi.getParentFrame()); + arrayUseTracker.addDependency(out[i], d); //Op defined by "d" needs to be executed before specified array can be closed + } else { + //All other ops... + Dep d = new OpDep(opName, outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame()); + arrayUseTracker.addDependency(out[i], d); //Op defined by "d" needs to be executed before specified array can be closed + } + } + } + + if (OUTER_FRAME.equals(outputFrameIter.getFrame()) && allReqVariables.contains(name)) { + //This variable is an output, record that in the array use tracker, so we don't deallocate it + arrayUseTracker.addDependency(out[i], new ReqOutputDep(name)); + } else if ((inputsForOps == null || inputsForOps.isEmpty()) && !arrayUseTracker.hasDependency(out[i])) { + //This particular array is not actually needed anywhere, so we can deallocate in immediately + //Possibly only a control dependency, or only one of the outputs of a multi-output op is used + if (log.isTraceEnabled()) { + log.trace("Found array id {} (output of {}) not required anywhere, deallocating", out[i].getId(), o.getName()); + } + mmgr.release(out[i]); + } + } + + //Mark current op dependency as satisfied... + Dep d = new OpDep(op.getName(), outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame()); + arrayUseTracker.markSatisfied(d, true); + + + //Close any no longer required arrays + if (arrayUseTracker.hasNewAllSatisfied()) { + List canClose = arrayUseTracker.getNewAllSatisfiedList(); + for (INDArray arr : canClose) { + if (log.isTraceEnabled()) { + log.trace("Closing array... id={}, {}", arr.getId(), arr.shapeInfoToString()); + } + mmgr.release(arr); + } + } + return out; } - public INDArray[] getOutputsHelper(DifferentialFunction op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, - Set constAndPhInputs){ + public INDArray[] doExec(DifferentialFunction op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + Set constAndPhInputs) { int totalInputs = (opInputs == null ? 0 : opInputs.size()) + (constAndPhInputs == null ? 0 : constAndPhInputs.size()) + (allIterInputs == null ? 0 : allIterInputs.size()); boolean constPhInput = (opInputs == null || opInputs.size() == 0) && (allIterInputs == null || allIterInputs.size() == 0); - if(op instanceof Identity ) { + if (op instanceof Identity) { Identity i = (Identity) op; String[] argNames = i.argNames(); - Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", argNames); - VarId vid = newVarId(argNames[0], outputFrameIter); - return new INDArray[]{nodeOutputs.get(vid)}; + Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", (Object) argNames); + VarId vid = outputFrameIter.toVarId(argNames[0]); - } else if(op instanceof Switch) { + INDArray orig = nodeOutputs.get(vid); + return new INDArray[]{orig}; + } else if (op instanceof Switch) { Switch s = (Switch) op; String[] argNames = s.argNames(); //Order: input, boolean array - VarId vidPredicate = newVarId(argNames[1], outputFrameIter); + VarId vidPredicate = outputFrameIter.toVarId(argNames[1]); INDArray predicate = this.nodeOutputs.get(vidPredicate); Preconditions.checkState(predicate.isScalar() && predicate.dataType() == DataType.BOOL, "Expected boolean predicate: got %ndSInfo", predicate); - VarId vid = newVarId(argNames[0], outputFrameIter); + VarId vid = outputFrameIter.toVarId(argNames[0]); if (predicate.getDouble(0) == 0.0) { return new INDArray[]{this.nodeOutputs.get(vid), null}; } else { return new INDArray[]{null, this.nodeOutputs.get(vid)}; } - } else if(op instanceof Enter) { + } else if (op instanceof Enter) { //Enter op: forwards input to specified execution frame - Enter e = (Enter)op; + Enter e = (Enter) op; String[] input = e.argNames(); - Preconditions.checkState(input.length == 1, "Expected only 1 arg name for enter op: got %s", input); + Preconditions.checkState(input.length == 1, "Expected only 1 arg name for enter op: got %s", (Object) input); Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for Enter op \"%s\", got %s+%s", e.getOwnName(), opInputs, constAndPhInputs); VarId inputVarId; - if(constPhInput) { + if (constPhInput) { //Constant or placeholder inputVarId = new VarId(constAndPhInputs.iterator().next(), OUTER_FRAME, 0, null); - } else if(allIterInputs != null && allIterInputs.size() > 0){ + } else if (allIterInputs != null && allIterInputs.size() > 0) { inputVarId = allIterInputs.iterator().next(); } else { inputVarId = opInputs.iterator().next(); @@ -187,332 +390,356 @@ public class InferenceSession extends AbstractSession 0){ + } else if (allIterInputs != null && allIterInputs.size() > 0) { inputVarId = allIterInputs.iterator().next(); } else { inputVarId = opInputs.iterator().next(); } INDArray exitInput = this.nodeOutputs.get(inputVarId); return new INDArray[]{exitInput}; - } else if(op instanceof NextIteration){ + } else if (op instanceof NextIteration) { //NextIteration op: forwards its single input to the output of the current frame, but increments the iteration number Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for NextIteration: got %s+%s", opInputs, constAndPhInputs); VarId in = (allIterInputs != null && !allIterInputs.isEmpty() ? allIterInputs.iterator().next() : opInputs.iterator().next()); Preconditions.checkState(outputFrameIter.getFrame().equals(in.getFrame()), "Expected same frame for NextIteration input vs. output:" + " got input %s, output %s", in, outputFrameIter); - Preconditions.checkState(outputFrameIter.getIteration() == in.getIteration()+1, "Expected output iteration for NextIteration output to" + + Preconditions.checkState(outputFrameIter.getIteration() == in.getIteration() + 1, "Expected output iteration for NextIteration output to" + " be 1 larger than the input iteration. Input: %s, output %s", in, outputFrameIter); INDArray inArr = this.nodeOutputs.get(in); + if (inArr == null) { + Preconditions.throwStateEx("Could not find array for NextIteration operation %s with output %s (frame=%s, iteration=%s)", + op.getOwnName(), sameDiff.getOps().get(op.getOwnName()).getOutputsOfOp().get(0), outputFrameIter.getFrame(), outputFrameIter.getIteration()); + } return new INDArray[]{inArr}; - } else if(op instanceof If) { - If i = (If) op; - String[] argNames = i.argNames(); //Order should be: [boolean], true, false - - - throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName()); - } else if(op instanceof Merge) { - //Merge avairable for forward pass when any of its inputs are available. When multiple are available, behaviour + } else if (op instanceof Merge) { + //Merge available for forward pass when any of its inputs are available. When multiple are available, behaviour // is undefined Merge m = (Merge) op; String[] in = sameDiff.getInputsForOp(op); for (String s : in) { - VarId vid = newVarId(s, outputFrameIter); + VarId vid = outputFrameIter.toVarId(s); if (nodeOutputs.containsKey(vid)) { log.trace("Returning input \"{}\" for merge node \"{}\"", m.getOwnName(), s); - return new INDArray[]{nodeOutputs.get(vid)}; + INDArray arr = nodeOutputs.get(vid); + Preconditions.checkState(arr != null, "Could not find output array for %s", vid); + return new INDArray[]{arr}; } } throw new IllegalStateException("Merge node " + m.getOwnName() + " has no available inputs (all inputs: " + Arrays.toString(in) + ") - should not be executed at this point"); - } else if(op instanceof LoopCond) { + } else if (op instanceof LoopCond) { //LoopCond just forwards scalar boolean to output LoopCond lc = (LoopCond) op; String[] argNames = lc.argNames(); - Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", argNames); - VarId vid = newVarId(argNames[0], outputFrameIter); + Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", (Object) argNames); + VarId vid = outputFrameIter.toVarId(argNames[0]); INDArray arr = nodeOutputs.get(vid); Preconditions.checkNotNull(arr, "Input to LoopCond op must not be null"); Preconditions.checkState(arr.isScalar() && arr.dataType() == DataType.BOOL, "LoopCond input must be a scalar boolean, got %ndShape"); return new INDArray[]{arr}; - } else if(op instanceof BaseTensorOp) { + } else if (op instanceof BaseTensorOp) { //TensorOps - special cases... - if (op instanceof TensorArray) { - //Create a TensorArray - VarId vid = newVarId(op.outputVariable().getVarName(), outputFrameIter); - Preconditions.checkState(!tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", vid); - tensorArrays.put(vid, new ArrayList()); - - // Note that TensorArray has 2 outputs - a 'dummy' SDVariable that represents it, and a second output (return a scalar 0.0) - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(true), Nd4j.scalar(0.0f)}; - } - } else if (op instanceof TensorArrayRead) { - //Do lookup and return - //Input 0 is the TensorArray (or dummy variable that represents it). Sometimes (for import) this can be like (TensorArray -> Enter -> TensorArrayRead) - //Input 1 is the index - SDVariable idxSDV = op.arg(1); - INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs); - Preconditions.checkState(idxArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", idxArr); - int i = idxArr.getInt(0); - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - - //Work out the frame/iteration: - VarId v = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(v == null && allIterInputs != null){ - v = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - - Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.getVarName()); - - while(sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter){ - //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead - //TODO also TensorArrayWrite, scatter, etc?? - inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); - v = newVarId(inTensorArray.getVarName(), v.getParentFrame()); - } - - List list = getTensorArrays().get(v); - Preconditions.checkState(list != null, "Could not find TensorList for %s", v); - Preconditions.checkState(list.size() > i, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", i, list.size(), v); - - INDArray out = list.get(i); - return new INDArray[]{out}; - } else if (op instanceof TensorArrayWrite) { - //TensorArrayWrite - also has a scalar 0.0 that it returns... - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - //Work out the varid (frame/iteration) of the tensor array: - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - - Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.getVarName()); - - while(sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter){ - //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite - //TODO also TensorArrayScatter, etc?? - inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); - tArr = newVarId(inTensorArray.getVarName(), tArr.getParentFrame()); - } - - //Input 0 is the TensorArray (or dummy variable that represents it) - but sometimes Enter, in TensorArray -> Enter -> TensorARrayRead - //Input 1 is the index - //Input 2 is the value to write - - String idxName = op.arg(1).getVarName(); - SDVariable idxSDV = sameDiff.getVariable(idxName); - INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs); - Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", idxArr); - int idx = idxArr.getInt(0); - - String inName = op.arg(2).getVarName(); - SDVariable inSDV = sameDiff.getVariable(inName); - INDArray arr = getArray(inSDV, opInputs, allIterInputs); - Preconditions.checkState(arr != null, "Could not find array for %s", inName); - - Preconditions.checkState(tensorArrays.containsKey(tArr), "Tensor array does not exist for %s", tArr); - //TODO is this always safe to insert by index for all execution orders? - List l = tensorArrays.get(tArr); //.set(idx, arr); - while (l.size() <= idx) { - //Can't use set(int, E) if index >= size - l.add(null); - } - l.set(idx, arr); - - //Return dummy array - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(0.0f)}; - } - } else if (op instanceof TensorArraySize) { - //Index 0 is the TensorArray (or dummy variable that represents it) - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - //Work out the varid (frame/iteration) of the tensor array: - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(DataType.INT, l.size())}; - } - } else if (op instanceof TensorArrayConcat) { - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - //TODO - empty checks. But is size 0 OK? - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - INDArray concat = Nd4j.concat(0, l.toArray(new INDArray[l.size()])); - return new INDArray[]{concat}; - } - } else if (op instanceof TensorArrayGather) { - //Input 0: the TensorArray - //Input 1: the indices (1d integer vector) - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - - String indicesName = op.arg(1).getVarName(); - SDVariable indicesSDV = sameDiff.getVariable(indicesName); - INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); - Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", idxArr, indicesName); - Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", idxArr.dataType(), indicesName); - - int[] idxArrInt = idxArr.toIntVector(); - - //Edge case: -1 means "all" - ArrayList newList = new ArrayList<>(); - if(idxArrInt.length == 1 && idxArrInt[0] == -1){ - newList.addAll(l); - } else { - for (int id : idxArrInt) { - Preconditions.checkState(id >=0,"Index for TensorArrayGather must be >= 0, got %s", id); - newList.add(l.get(id)); - } - } - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - INDArray out = Nd4j.pile(newList); - return new INDArray[]{out}; - } - } else if (op instanceof TensorArrayScatter) { - //Scatter values from a rank (N+1)d tensor into specific indices of the TensorArray - //Input 0: the TensorArray - //Input 1: the indices (1d integer vector) - //Input 2: The values to scatter - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.getVarName()); - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - - String indicesName = op.arg(1).getVarName(); - SDVariable indicesSDV = sameDiff.getVariable(indicesName); - INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); - Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", idxArr, indicesName); - Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", idxArr.dataType(), indicesName); - int[] idxs = idxArr.toIntVector(); - - String valuesName = op.arg(2).getVarName(); - SDVariable valuesSDV = sameDiff.getVariable(valuesName); - INDArray valuesArr = getArray(valuesSDV, opInputs, allIterInputs); - - while (l.size() <= idxs.length) { //Can't use set(int, E) if index >= size - l.add(null); - } - - //Edge case: idxs being [-1] means "all sub arrays" (i.e., "unstack" case) - if(idxs.length == 1 && idxs[0] == -1){ - idxs = ArrayUtil.range(0, (int)valuesArr.size(0)); - } - - INDArrayIndex[] idx = ArrayUtil.nTimes(valuesArr.rank(), NDArrayIndex.all(), INDArrayIndex.class); - for (int i = 0; i < idxs.length; i++) { - idx[0] = NDArrayIndex.point(i); - INDArray get = valuesArr.get(idx).dup(); - int outIdx = idxs[i]; - if(valuesArr.rank() == 2 && get.rank() == 2){ - //Workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7092 - get = get.reshape(get.length()); - } - if(valuesArr.rank() == 1 && get.rank() > 0){ - get = get.reshape(new long[0]); - } - l.set(outIdx, get); - } - - //Return dummy array - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(0.0f)}; - } - } else if (op instanceof TensorArraySplit) { - //Split values from a rank (N+1)d tensor into sequential indices of the TensorArray - //For example, orig=[8,2] sizearray with split (4,4) means TensorArray[0] = orig[0:4,:] and TensorArray[1] = orig[4:8,:] - //Input 0: the TensorArray - //Input 1: The values to split - //Input 2: the size of each split (1d integer vector) - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - - String splitName = op.arg(1).getVarName(); - INDArray splitArr = getArray(sameDiff.getVariable(splitName), opInputs, allIterInputs); - - - String sizeName = op.arg(2).getVarName(); - SDVariable sizeSDV = sameDiff.getVariable(sizeName); - INDArray sizeArr = getArray(sizeSDV, opInputs, allIterInputs); - Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName); - Preconditions.checkState(sizeArr.dataType().isIntType(), "Indices variable for TensorArraySplit should be an integer type, got %s for array %s", sizeArr.dataType(), sizeName); - int[] sizes = sizeArr.toIntVector(); - - while (l.size() <= sizes.length) { //Can't use set(int, E) if index >= size - l.add(null); - } - - INDArrayIndex[] idx = ArrayUtil.nTimes(splitArr.rank(), NDArrayIndex.all(), INDArrayIndex.class); - int soFar = 0; - for (int i = 0; i < sizes.length; i++) { - idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]); - INDArray sub = splitArr.get(idx).dup(); - l.set(i, sub); - soFar += sizes[i]; - } - //Return dummy array - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(0.0f)}; - } - } else { - throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName()); - } - } else if(op instanceof GradientBackwardsMarker){ - return new INDArray[]{Nd4j.scalar(1.0f)}; - } else if(op instanceof CustomOp){ - CustomOp c = (CustomOp)op; - Nd4j.getExecutioner().exec(c); + return getOutputsHelperTensorArrayOps(op, outputFrameIter, opInputs, allIterInputs); + } else if (op instanceof GradientBackwardsMarker) { + INDArray out = mmgr.allocate(false, DataType.FLOAT).assign(1.0f); + return new INDArray[]{out}; + } else if (op instanceof ExternalErrorsFunction) { + ExternalErrorsFunction fn = (ExternalErrorsFunction) op; + String n = fn.getGradPlaceholderName(); + INDArray arr = nodeOutputs.get(new VarId(n, OUTER_FRAME, 0, null)); + Preconditions.checkState(arr != null, "Could not find external errors placeholder array: %s", arr); + INDArray out = mmgr.allocate(false, arr.dataType(), arr.shape()); + out.assign(arr); + return new INDArray[]{out}; + } else if (op instanceof CustomOp) { + CustomOp c = (CustomOp) op; + Nd4j.exec(c); return c.outputArguments(); - } else if(op instanceof Op) { + } else if (op instanceof Op) { Op o = (Op) op; - Nd4j.getExecutioner().exec(o); + Nd4j.exec(o); return new INDArray[]{o.z()}; } else { throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName()); } } + /** + * Forward pass for TensorArray ops + */ + public INDArray[] getOutputsHelperTensorArrayOps(DifferentialFunction op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs) { + /* + TODO: TensorArray memory management note: For now, we'll close any INDArrays stored in the TensorArray at the end of + graph execution. This uses more memory than necessary for an earlier close strategy, but simplifies memory management. + This should be revisited and optimized later + */ + + if (op instanceof TensorArray) { + //Create a TensorArray + VarId vid = outputFrameIter.toVarId(op.outputVariable().getVarName()); + Preconditions.checkState(!tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", vid); + tensorArrays.put(vid, new ArrayList()); + + // Note that TensorArray has 2 outputs - a 'dummy' SDVariable that represents it, and a second output (return a scalar 0.0) + INDArray dummy = mmgr.allocate(false, DataType.BOOL).assign(true); + INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0); + return new INDArray[]{dummy, scalar}; + } else if (op instanceof TensorArrayRead) { + //Do lookup and return + //Input 0 is the TensorArray (or dummy variable that represents it). Sometimes (for import) this can be like (TensorArray -> Enter -> TensorArrayRead) + //Input 1 is the index + SDVariable idxSDV = op.arg(1); + INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs); + Preconditions.checkState(idxArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", idxArr); + int i = idxArr.getInt(0); + + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + + //Work out the frame/iteration: + VarId v = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + if (v == null && allIterInputs != null) { + v = lookup(inTensorArray.getVarName(), allIterInputs, false); + } + + Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.getVarName()); + + while (sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter) { + //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead + //TODO also TensorArrayWrite, scatter, etc?? + inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); + v = v.getParentFrame().toVarId(inTensorArray.getVarName()); + } + + List list = getTensorArrays().get(v); + Preconditions.checkState(list != null, "Could not find TensorList for %s", v); + Preconditions.checkState(list.size() > i, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", i, list.size(), v); + + INDArray out = list.get(i); + return new INDArray[]{out}; + } else if (op instanceof TensorArrayWrite) { + //TensorArrayWrite - also has a scalar 0.0 that it returns... + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + //Work out the varid (frame/iteration) of the tensor array: + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + } + + Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.getVarName()); + + while (sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter) { + //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite + //TODO also TensorArrayScatter, etc?? + inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); + tArr = tArr.getParentFrame().toVarId(inTensorArray.getVarName()); + } + + //Input 0 is the TensorArray (or dummy variable that represents it) - but sometimes Enter, in TensorArray -> Enter -> TensorARrayRead + //Input 1 is the index + //Input 2 is the value to write + + String idxName = op.arg(1).getVarName(); + SDVariable idxSDV = sameDiff.getVariable(idxName); + INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs); + Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", idxArr); + int idx = idxArr.getInt(0); + + String inName = op.arg(2).getVarName(); + SDVariable inSDV = sameDiff.getVariable(inName); + INDArray arr = getArray(inSDV, opInputs, allIterInputs); + Preconditions.checkState(arr != null, "Could not find array for %s", inName); + + Preconditions.checkState(tensorArrays.containsKey(tArr), "Tensor array does not exist for %s", tArr); + //TODO is this always safe to insert by index for all execution orders? + List l = tensorArrays.get(tArr); //.set(idx, arr); + while (l.size() <= idx) { + //Can't use set(int, E) if index >= size + l.add(null); + } + l.set(idx, arr); + + //Add a dependency + Dep d = new ExecDoneDep(); + arrayUseTracker.addDependency(arr, d); + + //Return dummy array + INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0); + return new INDArray[]{scalar}; + } else if (op instanceof TensorArraySize) { + //Index 0 is the TensorArray (or dummy variable that represents it) + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + //Work out the varid (frame/iteration) of the tensor array: + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); + + INDArray scalar = mmgr.allocate(false, DataType.INT).assign(l.size()); + return new INDArray[]{scalar}; + } else if (op instanceof TensorArrayConcat) { + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + + Concat c = new Concat(0, l.toArray(new INDArray[0])); + List shape = c.calculateOutputShape(); + INDArray out = mmgr.allocate(false, shape.get(0)); + c.setOutputArgument(0, out); + Nd4j.exec(c); + return new INDArray[]{out}; + } else if (op instanceof TensorArrayGather) { + //Input 0: the TensorArray + //Input 1: the indices (1d integer vector) + + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); + + String indicesName = op.arg(1).getVarName(); + SDVariable indicesSDV = sameDiff.getVariable(indicesName); + INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); + Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", idxArr, indicesName); + Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", idxArr.dataType(), indicesName); + + int[] idxArrInt = idxArr.toIntVector(); + + //Edge case: -1 means "all" + List newList = new ArrayList<>(); + if (idxArrInt.length == 1 && idxArrInt[0] == -1) { + newList.addAll(l); + } else { + for (int id : idxArrInt) { + Preconditions.checkState(id >= 0, "Index for TensorArrayGather must be >= 0, got %s", id); + newList.add(l.get(id)); + } + } + + Stack s = new Stack(newList.toArray(new INDArray[0]), null, 0); + List shape = s.calculateOutputShape(); + INDArray out = mmgr.allocate(false, shape.get(0)); + s.setOutputArgument(0, out); + Nd4j.exec(s); + return new INDArray[]{out}; + } else if (op instanceof TensorArrayScatter) { + //Scatter values from a rank (N+1)d tensor into specific indices of the TensorArray + //Input 0: the TensorArray + //Input 1: the indices (1d integer vector) + //Input 2: The values to scatter + + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.getVarName()); + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); + + String indicesName = op.arg(1).getVarName(); + SDVariable indicesSDV = sameDiff.getVariable(indicesName); + INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); + Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", idxArr, indicesName); + Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", idxArr.dataType(), indicesName); + int[] idxs = idxArr.toIntVector(); + + String valuesName = op.arg(2).getVarName(); + SDVariable valuesSDV = sameDiff.getVariable(valuesName); + INDArray valuesArr = getArray(valuesSDV, opInputs, allIterInputs); + + while (l.size() <= idxs.length) { //Can't use set(int, E) if index >= size + l.add(null); + } + + //Edge case: idxs being [-1] means "all sub arrays" (i.e., "unstack" case) + if (idxs.length == 1 && idxs[0] == -1) { + idxs = ArrayUtil.range(0, (int) valuesArr.size(0)); + } + + INDArrayIndex[] idx = ArrayUtil.nTimes(valuesArr.rank(), NDArrayIndex.all(), INDArrayIndex.class); + for (int i = 0; i < idxs.length; i++) { + idx[0] = NDArrayIndex.point(i); + INDArray get = mmgr.dup(valuesArr.get(idx)); + int outIdx = idxs[i]; + if (valuesArr.rank() == 1 && get.rank() > 0) { + get = get.reshape(); + } + l.set(outIdx, get); + + //Add dependency for values array until end of execution + arrayUseTracker.addDependency(get, new ExecDoneDep()); + } + + //Return dummy array + INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0); + return new INDArray[]{scalar}; + } else if (op instanceof TensorArraySplit) { + //Split values from a rank (N+1)d tensor into sequential indices of the TensorArray + //For example, orig=[8,2] sizearray with split (4,4) means TensorArray[0] = orig[0:4,:] and TensorArray[1] = orig[4:8,:] + //Input 0: the TensorArray + //Input 1: The values to split + //Input 2: the size of each split (1d integer vector) + + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); + + String splitName = op.arg(1).getVarName(); + INDArray splitArr = getArray(sameDiff.getVariable(splitName), opInputs, allIterInputs); + + + String sizeName = op.arg(2).getVarName(); + SDVariable sizeSDV = sameDiff.getVariable(sizeName); + INDArray sizeArr = getArray(sizeSDV, opInputs, allIterInputs); + Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName); + Preconditions.checkState(sizeArr.dataType().isIntType(), "Indices variable for TensorArraySplit should be an integer type, got %s for array %s", sizeArr.dataType(), sizeName); + int[] sizes = sizeArr.toIntVector(); + + while (l.size() <= sizes.length) { //Can't use set(int, E) if index >= size + l.add(null); + } + + INDArrayIndex[] idx = ArrayUtil.nTimes(splitArr.rank(), NDArrayIndex.all(), INDArrayIndex.class); + int soFar = 0; + for (int i = 0; i < sizes.length; i++) { + idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]); + INDArray sub = mmgr.dup(splitArr.get(idx)); + l.set(i, sub); + soFar += sizes[i]; + + //Add dependency for values array until end of execution + arrayUseTracker.addDependency(sub, new ExecDoneDep()); + } + + //Return dummy array + INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0); + return new INDArray[]{scalar}; + } else { + throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName()); + } + } + + @Override public INDArray getConstantOrVariable(String variableName) { SDVariable v = sameDiff.getVariable(variableName); @@ -522,21 +749,19 @@ public class InferenceSession extends AbstractSession opInputs, Set allIterInputs, - Set constAndPhInputs, Map placeholderValues) { + public SameDiffOp getAndParameterizeOp(String opName, FrameIter frameIter, Set opInputs, Set allIterInputs, + Set constAndPhInputs, Map placeholderValues, Set allReqVariables) { + SameDiffOp sdo = sameDiff.getOps().get(opName); + DifferentialFunction df = sdo.getOp(); - DifferentialFunction df = sameDiff.getOpById(opName); + //TODO Switch to OpContext - and make sure executing like that is thread safe (i.e., array fields in ops are not used etc) - //TODO We should clone these ops - probably - as we don't want them shared between threads/sessions! - //But let's only clone them *once* and cache in inference session - not on every exec + Preconditions.checkNotNull(df, "No differential function found with name \"%s\"", opName); - Preconditions.checkNotNull(df, "No differential function fond with name %s", opName); - - if(df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration || - df instanceof Merge || df instanceof Switch || df instanceof If || df instanceof While || - df instanceof BaseTensorOp){ + if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration || + df instanceof Merge || df instanceof Switch || df instanceof BaseTensorOp) { //Control dependencies and tensor ops (like TensorArray, TensorArrayRead etc) don't need inputs set, execution is a special case - return df; + return sdo; } //Infer the args based on the inputs (variable + frame + iteration) @@ -546,123 +771,41 @@ public class InferenceSession extends AbstractSession constEnterInputs = null; - if(numArgs != (numNonConstIns + numConstPhIns + numNonConstInsAllIters)){ - boolean anyConstEnterInputs = false; - SDVariable[] args = df.args(); - for(SDVariable v : args){ - Variable var = sameDiff.getVariables().get(v.getVarName()); - //Nested enter case: - DifferentialFunction inputVarFn = (var.getOutputOfOp() == null ? null : sameDiff.getOps().get(var.getOutputOfOp()).getOp()); - if(inputVarFn instanceof Enter && ((Enter)inputVarFn).isConstant()){ - anyConstEnterInputs = true; - if(constEnterInputs == null) - constEnterInputs = new HashSet<>(); - constEnterInputs.add(v.getVarName()); - } - } - - int constEnterInputCount = 0; - if(anyConstEnterInputs){ - /* - 2019/01/26: AB - Resolve nested enter inputs (constants 2+ enters in) - Why this hack is necessary: consider the following (sub) graph: constX -> Enter(a) -> Enter(b) -> opY - On iterations (a=0, b=0) all is well, opY gets triggered as normal. - On iterations (a>0, b=*) the "opY is available for exec" won't be triggered. - This is because Enter(a) is only executed once, on iteration 0 of the outer loop. - Consequently, Enter(b) is not triggered as available on iteration 1+. - When we do the lookup for the actual array to use for op execution (i.e., get inputs for opY(a=1,b=0)) - it won't be found. - This is a bit of an ugly hack, though I've yet to find a cleaner solution. - It should only be required with the combination of: constants, 2 levels of enters, and more than 1 iteration in each loop. - */ - - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - for(String s : constEnterInputs){ - //First: check if this has already been provided - if(constAndPhInputs != null && constAndPhInputs.contains(s)){ - //already resolved/provided - continue; - } - boolean found = false; - if(allIterInputs != null) { - for (VarId vid : allIterInputs) { - if (s.equals(vid.getVariable())) { - //Already resolved/provided - found = true; - break; - } - } - } - if(found) - continue; - - constEnterInputCount++; - } - } - - if(numArgs > 1){ + if (numArgs != (numNonConstIns + numConstPhIns + numNonConstInsAllIters)) { + if (numArgs > 1) { //Might be due to repeated inputs Set uniqueArgNames = new HashSet<>(); Collections.addAll(uniqueArgNames, argNames); - Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters + constEnterInputCount), + Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters), "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(), opName, uniqueArgNames, opInputs, constAndPhInputs); } else { - Preconditions.checkState(numArgs == (numNonConstIns + numConstPhIns + constEnterInputCount), + Preconditions.checkState(numArgs == (numNonConstIns + numConstPhIns), "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(), opName, argNames, opInputs, constAndPhInputs); } } INDArray[] args = null; - if(argNames != null && argNames.length > 0) { + if (argNames != null && argNames.length > 0) { args = new INDArray[argNames.length]; int i = 0; - for(String s : argNames){ + for (String s : argNames) { SDVariable v = sameDiff.getVariable(s); - if(v.isConstant()) { + if (v.isConstant()) { args[i] = v.getArr(); - } else if(v.isPlaceHolder()) { + } else if (v.getVariableType() == VariableType.VARIABLE) { + args[i] = v.getArr(); + } else if (v.isPlaceHolder()) { Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array provided for placeholder %s", s); args[i] = placeholderValues.get(s); - } else if(constEnterInputs != null && constEnterInputs.contains(s)){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - VarId vid = newVarId(s, frameIter.clone()); - vid.setIteration(0); - FrameIter toZero = vid.getParentFrame(); - while(toZero != null){ - toZero.setIteration(0); - toZero = toZero.getParentFrame(); - } - INDArray arr = this.nodeOutputs.get(vid); - args[i] = arr; } else { - if(opInputs != null) { - for (VarId vid : opInputs) { - if (vid.getVariable().equals(s)) { - args[i] = this.nodeOutputs.get(vid); - break; - } - } - } - if(args[i] == null && allIterInputs != null){ - for(VarId vid : allIterInputs){ - if(vid.getVariable().equals(s)){ - args[i] = this.nodeOutputs.get(vid); - break; - } - } - } + VarId vid = lookup(s, opInputs, allIterInputs, true); + args[i] = nodeOutputs.get(vid); } Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.getVarName()); i++; } - } //Set the op inputs and output arguments @@ -671,19 +814,24 @@ public class InferenceSession extends AbstractSession 0; - if(df instanceof CustomOp){ + if (df instanceof CustomOp) { DynamicCustomOp customOp = (DynamicCustomOp) df; - if(args != null) { + if (args != null) { customOp.setInputArguments(args); } - df.resolvePropertiesFromSameDiffBeforeExecution(); + if (df instanceof Identity) { + //We don't need to allocate an output array for Identity, we pass through the input array without copying + return sdo; + } + + df.resolvePropertiesFromSameDiffBeforeExecution(); //TODO This is to be removed List outShape = customOp.calculateOutputShape(); Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName()); String[] outNames = df.outputVariablesNames(); Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" + " with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length); - for( int i=0; i 0){ + if (args != null && args.length > 0) { op.setX(args[0]); if (args.length == 2 && !axisArg) op.setY(args[1]); @@ -749,51 +894,105 @@ public class InferenceSession extends AbstractSession outputShape = ((BaseOp) op).calculateOutputShape(); Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass()); INDArray z = op.z(); - if (z == null || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) { + if (z == null || z.wasClosed() || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) { if (log.isTraceEnabled()) { log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}", op.getClass().getSimpleName(), (z == null ? null : Arrays.toString(z.shape())), outputShape.get(0).toString()); } LongShapeDescriptor lsd = outputShape.get(0); - try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - z = Nd4j.create(lsd, false); - } + + boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]); + z = mmgr.allocate(isOutput, lsd); op.setZ(z); } } df.resolvePropertiesFromSameDiffBeforeExecution(); } - return df; + return sdo; } - protected INDArray getArray(SDVariable sdv, Collection opInputs, Collection allIterInputs){ + protected INDArray getArray(SDVariable sdv, Collection opInputs, Collection allIterInputs) { String n = sdv.getVarName(); - if(sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE){ + if (sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE) { return getConstantOrVariable(n); } else { - VarId inVarId = null; - if(opInputs != null){ - inVarId = lookup(n, opInputs, false); - } - if(inVarId == null && allIterInputs != null && !allIterInputs.isEmpty()){ - inVarId = lookup(n, allIterInputs, false); - } - Preconditions.checkState(inVarId != null,"Could not find array for variable %s", sdv.getVarName()); + VarId inVarId = lookup(n, opInputs, allIterInputs, false); + Preconditions.checkState(inVarId != null, "Could not find array for variable %s", sdv.getVarName()); return nodeOutputs.get(inVarId); } } + + @Data + public abstract static class Dep { + protected String frame; + protected FrameIter parentFrame; + } + + @AllArgsConstructor + @Data + @EqualsAndHashCode(callSuper = true) + public static class OpDep extends Dep { + protected String opName; + protected int iter; + + protected OpDep(@NonNull String opName, @NonNull String frame, int iter, FrameIter parentFrame) { + this.opName = opName; + this.frame = frame; + this.iter = iter; + this.parentFrame = parentFrame; + } + + @Override + public String toString() { + return "OpDep(" + opName + ",frame=" + frame + ",iter=" + iter + (parentFrame == null ? "" : ",parent=" + parentFrame) + ")"; + } + } + + @Data + @EqualsAndHashCode(callSuper = true) + @AllArgsConstructor + protected static class PlaceholderDep extends Dep { + protected String phName; + } + + @Data + @EqualsAndHashCode(callSuper = true) + @AllArgsConstructor + protected static class VariableDep extends Dep { + protected String varName; + } + + @Data + @EqualsAndHashCode(callSuper = true) + @AllArgsConstructor + protected static class ConstantDep extends Dep { + protected String constName; + } + + @Data + @EqualsAndHashCode(callSuper = true) + @AllArgsConstructor + protected static class ReqOutputDep extends Dep { + protected String outputName; + } + + @Data + @EqualsAndHashCode(callSuper = true) + @NoArgsConstructor + protected static class ExecDoneDep extends Dep { + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java index de3e96c2e..8e9b45067 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java @@ -30,8 +30,10 @@ import java.util.List; @Builder public class SameDiffOp { protected String name; - protected DifferentialFunction op; //Actual op (note: should be mutable: i.e., cloneable, no arrays set) - protected List inputsToOp; //Name of SDVariables as input - protected List outputsOfOp; //Name of SDVariables as output - protected List controlDeps; //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec) + protected DifferentialFunction op; //Actual op (note: should be mutable: i.e., cloneable, no arrays set) + protected List inputsToOp; //Name of SDVariables as input + protected List outputsOfOp; //Name of SDVariables as output + protected List controlDeps; //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec) + protected List varControlDeps; //Variables (constants, placeholders, etc) that are control dependencies for this op + protected List controlDepFor; //Name of the variables that this op is a control dependency for } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SessionMemMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SessionMemMgr.java new file mode 100644 index 000000000..b54db548a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SessionMemMgr.java @@ -0,0 +1,60 @@ +package org.nd4j.autodiff.samediff.internal; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; + +import java.io.Closeable; + +/** + * SessionMemMgr - aka "Session Memory Manager" is responsible for allocating, managing, and deallocating memory used + * during SameDiff execution.
+ * This interface allows different memory management strategies to be used, abstracted away from the actual graph + * execution logic + * + * @author Alex Black + */ +public interface SessionMemMgr extends Closeable { + + /** + * Allocate an array with the specified datatype and shape.
+ * NOTE: This array should be assumed to be uninitialized - i.e., contains random values. + * + * @param detached If true: the array is safe to return outside of the SameDiff session run (for example, the array + * is one that may be returned to the user) + * @param dataType Datatype of the returned array + * @param shape Array shape + * @return The newly allocated (uninitialized) array + */ + INDArray allocate(boolean detached, DataType dataType, long... shape); + + /** + * As per {@link #allocate(boolean, DataType, long...)} but from a LongShapeDescriptor instead + */ + INDArray allocate(boolean detached, LongShapeDescriptor descriptor); + + /** + * Allocate an uninitialized array with the same datatype and shape as the specified array + */ + INDArray ulike(INDArray arr); + + /** + * Duplicate the specified array, to an array that is managed/allocated by the session memory manager + */ + INDArray dup(INDArray arr); + + /** + * Release the array. All arrays allocated via one of the allocate methods should be returned here once they are no + * longer used, and all references to them should be cleared. + * After calling release, anything could occur to the array - deallocated, workspace closed, reused, etc. + * + * @param array The array that can be released + */ + void release(INDArray array); + + /** + * Close the session memory manager and clean up any memory / resources, if any + */ + void close(); + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java new file mode 100644 index 000000000..43082c4de --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java @@ -0,0 +1,232 @@ +package org.nd4j.autodiff.samediff.internal; + +import com.sun.prism.paint.Gradient; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.learning.GradientUpdater; +import org.nd4j.linalg.learning.regularization.Regularization; +import org.nd4j.linalg.primitives.AtomicDouble; + +import java.util.*; + +/** + * TrainingSession extends InferenceSession, to add training-specific functionality:
+ * - Application of regularization (L1, L2, weight decay etc)
+ * - Inline updating of variables, using updater/optimizer (Adam, Nesterov, SGD, etc)
+ * - Calculation of regularization scores (Score for L1, L2, etc) + * + * @author Alex Black + */ +@Slf4j +public class TrainingSession extends InferenceSession { + + protected TrainingConfig config; + protected Map gradVarToVarMap; + protected Map updaters; + protected Map lossVarsToLossIdx; + protected double[] currIterLoss; + protected Map, AtomicDouble> currIterRegLoss; + protected List listeners; + + + public TrainingSession(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * Perform one iteration of training - i.e., do forward and backward passes, and update the parameters + * + * @param config Training configuration + * @param placeholders Current placeholders + * @param paramsToTrain Set of parameters that will be trained + * @param updaters Current updater state + * @param batch Current data/batch (mainly for listeners, should have already been converted to placeholders map) + * @param lossVariables Loss variables (names) + * @param listeners Listeners (if any) + * @param at Current epoch, iteration, etc + * @return The Loss at the current iteration + */ + public Loss trainingIteration(TrainingConfig config, Map placeholders, Set paramsToTrain, Map updaters, + MultiDataSet batch, List lossVariables, List listeners, At at) { + this.config = config; + this.updaters = updaters; + + //Preprocess listeners, get the relevant ones + if (listeners == null) { + this.listeners = null; + } else { + List filtered = new ArrayList<>(); + for (Listener l : listeners) { + if (l.isActive(at.operation())) { + filtered.add(l); + } + } + this.listeners = filtered.isEmpty() ? null : filtered; + } + + List requiredActivations = new ArrayList<>(); + gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for + for (String s : paramsToTrain) { + Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s); + SDVariable v = sameDiff.getVariable(s); + Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE, "Can only train VARIABLE type variable - \"%s\" has type %s", + s, v.getVariableType()); + SDVariable grad = sameDiff.getVariable(s).getGradient(); + if (grad == null) { + //In some cases, a variable won't actually impact the loss value, and hence won't have a gradient associated with it + //For example: floatVar -> cast to integer -> cast to float -> sum -> loss + //In this case, the gradient of floatVar isn't defined (due to no floating point connection to the loss) + continue; + } + + requiredActivations.add(grad.getVarName()); + + gradVarToVarMap.put(grad.getVarName(), s); + } + + //Set up losses + lossVarsToLossIdx = new LinkedHashMap<>(); + List lossVars; + currIterLoss = new double[lossVariables.size()]; + currIterRegLoss = new HashMap<>(); + for (int i = 0; i < lossVariables.size(); i++) { + lossVarsToLossIdx.put(lossVariables.get(i), i); + } + + //Do training iteration + List outputVars = new ArrayList<>(gradVarToVarMap.keySet()); //TODO this should be empty, and grads calculated in requiredActivations + Map m = output(outputVars, placeholders, batch, requiredActivations, listeners, at); + + + double[] finalLoss = new double[currIterLoss.length + currIterRegLoss.size()]; + System.arraycopy(currIterLoss, 0, finalLoss, 0, currIterLoss.length); + if (currIterRegLoss.size() > 0) { + lossVars = new ArrayList<>(lossVariables.size() + currIterRegLoss.size()); + lossVars.addAll(lossVariables); + int s = currIterRegLoss.size(); + //Collect regularization losses + for (Map.Entry, AtomicDouble> entry : currIterRegLoss.entrySet()) { + lossVars.add(entry.getKey().getSimpleName()); + finalLoss[s] = entry.getValue().get(); + } + } else { + lossVars = lossVariables; + } + + Loss loss = new Loss(lossVars, finalLoss); + if (listeners != null) { + for (Listener l : listeners) { + if (l.isActive(Operation.TRAINING)) { + l.iterationDone(sameDiff, at, batch, loss); + } + } + } + + return loss; + } + + @Override + public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + Set constAndPhInputs, List listeners, At at, MultiDataSet batch, Set allReqVariables) { + //Get outputs from InferenceSession + INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables); + + List outputs = op.getOutputsOfOp(); + int outIdx = 0; + for (String s : outputs) { + //If this is a loss variable - record it + if (lossVarsToLossIdx.containsKey(s)) { + int lossIdx = lossVarsToLossIdx.get(s); + INDArray arr = out[outIdx]; + double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue(); + currIterLoss[lossIdx] += l; + } + + //If this is a gradient variable - apply the updater and update the parameter array in-line + if (gradVarToVarMap.containsKey(s)) { + String varName = gradVarToVarMap.get(s); + //log.info("Calculated gradient for variable \"{}\": (grad var name: \"{}\")", varName, s); + + Variable gradVar = sameDiff.getVariables().get(s); + if (gradVar.getInputsForOp() != null && gradVar.getInputsForOp().isEmpty()) { + //Should be rare, and we should handle this by tracking dependencies, and only update when safe + // (i.e., dependency tracking) + throw new IllegalStateException("Op depends on gradient variable: " + s + " for variable " + varName); + } + + GradientUpdater u = updaters.get(varName); + Preconditions.checkState(u != null, "No updater found for variable \"%s\"", varName); + + Variable var = sameDiff.getVariables().get(varName); + INDArray gradArr = out[outIdx]; + INDArray paramArr = var.getVariable().getArr(); + + //Pre-updater regularization (L1, L2) + List r = config.getRegularization(); + if (r != null && r.size() > 0) { + double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0; + for (Regularization reg : r) { + if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) { + if (this.listeners != null) { + double score = reg.score(paramArr, at.iteration(), at.epoch()); + if (!currIterRegLoss.containsKey(reg.getClass())) { + currIterRegLoss.put(reg.getClass(), new AtomicDouble()); + } + currIterRegLoss.get(reg.getClass()).addAndGet(score); + } + reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch()); + } + } + } + + u.applyUpdater(gradArr, at.iteration(), at.epoch()); + + //Post-apply regularization (weight decay) + if (r != null && r.size() > 0) { + double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0; + for (Regularization reg : r) { + if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) { + if (this.listeners != null) { + double score = reg.score(paramArr, at.iteration(), at.epoch()); + if (!currIterRegLoss.containsKey(reg.getClass())) { + currIterRegLoss.put(reg.getClass(), new AtomicDouble()); + } + currIterRegLoss.get(reg.getClass()).addAndGet(score); + } + reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch()); + } + } + } + + if (listeners != null) { + for (Listener l : listeners) { + if (l.isActive(at.operation())) + l.preUpdate(sameDiff, at, var, gradArr); + } + } + + //Update: + if (config.isMinimize()) { + paramArr.subi(gradArr); + } else { + paramArr.addi(gradArr); + } + log.trace("Applied updater to gradient and updated variable: {}", varName); + } + + outIdx++; + } + + return out; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java index 670b21dda..e8041955b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java @@ -35,8 +35,7 @@ public class Variable { protected List controlDepsForOp; //if a op control dependency (x -> opY) exists, then "opY" will be in this list protected List controlDepsForVar; //if a variable control dependency (x -> varY) exists, then "varY" will be in this list protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of - protected List controlDeps; //Control dependencies: name of variables that must be available before this variable is considered available for execution - protected int outputOfOpIdx; //Index of the output for the op (say, variable is output number 2 of op "outputOfOp") + protected List controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution protected SDVariable gradient; //Variable corresponding to the gradient of this variable protected int variableIndex = -1; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/AbstractMemoryMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/AbstractMemoryMgr.java new file mode 100644 index 000000000..e498deaf5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/AbstractMemoryMgr.java @@ -0,0 +1,25 @@ +package org.nd4j.autodiff.samediff.internal.memory; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * Abstract memory manager, that implements ulike and dup methods using the underlying allocate methods + * + * @author Alex Black + */ +public abstract class AbstractMemoryMgr implements SessionMemMgr { + + @Override + public INDArray ulike(@NonNull INDArray arr) { + return allocate(false, arr.dataType(), arr.shape()); + } + + @Override + public INDArray dup(@NonNull INDArray arr) { + INDArray out = ulike(arr); + out.assign(arr); + return out; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCloseMemoryMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCloseMemoryMgr.java new file mode 100644 index 000000000..24992c50b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCloseMemoryMgr.java @@ -0,0 +1,43 @@ +package org.nd4j.autodiff.samediff.internal.memory; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.factory.Nd4j; + +/** + * A simple memory management strategy that deallocates memory as soon as it is no longer needed.
+ * This should result in a minimal amount of memory, but will have some overhead - notably, the cost of deallocating + * and reallocating memory all the time. + * + * @author Alex Black + */ +@Slf4j +public class ArrayCloseMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr { + + @Override + public INDArray allocate(boolean detached, DataType dataType, long... shape) { + return Nd4j.createUninitialized(dataType, shape); + } + + @Override + public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) { + return Nd4j.create(descriptor, false); + } + + @Override + public void release(@NonNull INDArray array) { + if (!array.wasClosed() && array.closeable()) { + array.close(); + log.trace("Closed array (deallocated) - id={}", array.getId()); + } + } + + @Override + public void close() { + //No-op + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/CloseValidationMemoryMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/CloseValidationMemoryMgr.java new file mode 100644 index 000000000..8417bfb35 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/CloseValidationMemoryMgr.java @@ -0,0 +1,168 @@ +package org.nd4j.autodiff.samediff.internal.memory; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.DependencyList; +import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker; +import org.nd4j.autodiff.samediff.internal.InferenceSession; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.primitives.Pair; + +import java.util.*; + +/** + * A {@link SessionMemMgr} that wraps an existing memory manager, to ensure that:
+ * - All arrays that are supposed to be closed, have been closed
+ * - Arrays are only passed to the close method exactly one (unless they are requested outputs)
+ * - Arrays that are passed to the close method were originally allocated by the session memory manager
+ *
+ * How to use:
+ * 1. Perform an inference or training iteration, as normal
+ * 2. Call {@link #assertAllReleasedExcept(Collection)} with the output arrays
+ *

+ * NOTE: This is intended for debugging and testing only + * + * @author Alex Black + */ +@Slf4j +public class CloseValidationMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr { + + private final SameDiff sd; + private final SessionMemMgr underlying; + private final Map released = new IdentityHashMap<>(); + + public CloseValidationMemoryMgr(SameDiff sd, SessionMemMgr underlying) { + this.sd = sd; + this.underlying = underlying; + } + + @Override + public INDArray allocate(boolean detached, DataType dataType, long... shape) { + INDArray out = underlying.allocate(detached, dataType, shape); + released.put(out, false); + return out; + } + + @Override + public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) { + INDArray out = underlying.allocate(detached, descriptor); + released.put(out, false); + return out; + } + + @Override + public void release(INDArray array) { + Preconditions.checkState(released.containsKey(array), "Attempting to release an array that was not allocated by" + + " this memory manager: id=%s", array.getId()); + if (released.get(array)) { + //Already released + InferenceSession is = sd.getSessions().get(Thread.currentThread().getId()); + IdentityDependencyTracker arrayUseTracker = is.getArrayUseTracker(); + DependencyList dl = arrayUseTracker.getDependencies(array); + System.out.println(dl); + if (dl.getDependencies() != null) { + for (InferenceSession.Dep d : dl.getDependencies()) { + System.out.println(d + ": " + arrayUseTracker.isSatisfied(d)); + } + } + if (dl.getOrDependencies() != null) { + for (Pair p : dl.getOrDependencies()) { + System.out.println(p + " - (" + arrayUseTracker.isSatisfied(p.getFirst()) + "," + arrayUseTracker.isSatisfied(p.getSecond())); + } + } + } + Preconditions.checkState(!released.get(array), "Attempting to release an array that was already deallocated by" + + " an earlier release call to this memory manager: id=%s", array.getId()); + log.trace("Released array: id = {}", array.getId()); + released.put(array, true); + } + + @Override + public void close() { + underlying.close(); + } + + /** + * Check that all arrays have been released (after an inference call) except for the specified arrays. + * + * @param except Arrays that should not have been closed (usually network outputs) + */ + public void assertAllReleasedExcept(@NonNull Collection except) { + Set allVarPhConst = null; + + for (INDArray arr : except) { + if (!released.containsKey(arr)) { + //Check if constant, variable or placeholder - maybe user requested that out + if (allVarPhConst == null) + allVarPhConst = identitySetAllConstPhVar(); + if (allVarPhConst.contains(arr)) + continue; //OK - output is a constant, variable or placeholder, hence it's fine it's not allocated by the memory manager + + throw new IllegalStateException("Array " + arr.getId() + " was not originally allocated by the memory manager"); + } + + boolean released = this.released.get(arr); + if (released) { + throw new IllegalStateException("Specified output array (id=" + arr.getId() + ") should not have been deallocated but was"); + } + } + + Set exceptSet = Collections.newSetFromMap(new IdentityHashMap()); + exceptSet.addAll(except); + + int numNotClosed = 0; + Set notReleased = Collections.newSetFromMap(new IdentityHashMap()); + InferenceSession is = sd.getSessions().get(Thread.currentThread().getId()); + IdentityDependencyTracker arrayUseTracker = is.getArrayUseTracker(); + for (Map.Entry e : released.entrySet()) { + INDArray a = e.getKey(); + if (!exceptSet.contains(a)) { + boolean b = e.getValue(); + if (!b) { + notReleased.add(a); + numNotClosed++; + log.info("Not released: array id {}", a.getId()); + DependencyList list = arrayUseTracker.getDependencies(a); + List l = list.getDependencies(); + List> l2 = list.getOrDependencies(); + if (l != null) { + for (InferenceSession.Dep d : l) { + if (!arrayUseTracker.isSatisfied(d)) { + log.info(" Not satisfied: {}", d); + } + } + } + if (l2 != null) { + for (Pair d : l2) { + if (!arrayUseTracker.isSatisfied(d.getFirst()) && !arrayUseTracker.isSatisfied(d.getSecond())) { + log.info(" Not satisfied: {}", d); + } + } + } + } + } + } + + if (numNotClosed > 0) { + System.out.println(sd.summary()); + throw new IllegalStateException(numNotClosed + " arrays were not released but should have been"); + } + } + + protected Set identitySetAllConstPhVar() { + Set set = Collections.newSetFromMap(new IdentityHashMap()); + for (SDVariable v : sd.variables()) { + if (v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT || v.getVariableType() == VariableType.PLACEHOLDER) { + set.add(v.getArr()); + } + } + return set; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/NoOpMemoryMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/NoOpMemoryMgr.java new file mode 100644 index 000000000..30b891c2f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/NoOpMemoryMgr.java @@ -0,0 +1,42 @@ +package org.nd4j.autodiff.samediff.internal.memory; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.factory.Nd4j; + +/** + * A simple "no-op" memory manager that relies on JVM garbage collector for memory management. + * Assuming other references have been cleared (they should have been) the arrays will be cleaned up by the + * garbage collector at some point. + * + * This memory management strategy is not recommended for performance or memory reasons, and should only be used + * for testing and debugging purposes + * + * @author Alex Black + */ +public class NoOpMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr { + + @Override + public INDArray allocate(boolean detached, DataType dataType, long... shape) { + return Nd4j.createUninitialized(dataType, shape); + } + + @Override + public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) { + return Nd4j.create(descriptor, false); + } + + @Override + public void release(@NonNull INDArray array) { + //No-op, rely on GC to clear arrays + } + + @Override + public void close() { + //No-op + } + +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index a17cb41b1..668a7a4a9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -90,10 +90,10 @@ public class SDNN extends SDOps { } /** - * @see #biasAdd(String, SDVariable, SDVariable) + * @see #biasAdd(String, SDVariable, SDVariable, boolean) */ - public SDVariable biasAdd(SDVariable input, SDVariable bias) { - return biasAdd(null, input, bias); + public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { + return biasAdd(null, input, bias, nchw); } /** @@ -102,12 +102,14 @@ public class SDNN extends SDOps { * @param name Name of the output variable * @param input 4d input variable * @param bias 1d bias + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. + * Unused for 2d inputs * @return Output variable */ - public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) { + public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) { validateFloatingPoint("biasAdd", "input", input); validateFloatingPoint("biasAdd", "bias", bias); - SDVariable ret = f().biasAdd(input, bias); + SDVariable ret = f().biasAdd(input, bias, nchw); return updateVariableNameAndReference(ret, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index cce38cf24..39e7e479f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff.serde; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.shade.guava.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; import java.nio.ByteOrder; @@ -847,6 +848,28 @@ public class FlatBuffersMapper { } int outTypesOffset = FlatNode.createOutputTypesVector(bufferBuilder, outTypes); + //Control dependencies: + SameDiffOp sdo = sameDiff.getOps().get(node.getOwnName()); + + int opCds = 0; + int[] opCdsArr = mapOrNull(sdo.getControlDeps(), bufferBuilder); + if(opCdsArr != null){ + opCds = FlatNode.createControlDepsVector(bufferBuilder, opCdsArr); + } + + int varCds = 0; + int[] varCdsArr = mapOrNull(sdo.getVarControlDeps(), bufferBuilder); + if(varCdsArr != null){ + varCds = FlatNode.createVarControlDepsVector(bufferBuilder, varCdsArr); + } + + int cdsFor = 0; + int[] cdsForArr = mapOrNull(sdo.getControlDepFor(), bufferBuilder); + if(cdsForArr != null){ + cdsFor = FlatNode.createControlDepForVector(bufferBuilder, cdsForArr); + } + + int flatNode = FlatNode.createFlatNode( bufferBuilder, ownId, @@ -867,12 +890,26 @@ public class FlatBuffersMapper { outVarNamesOffset, opNameOffset, outTypesOffset, //Output types - scalar + scalar, + opCds, + varCds, + cdsFor ); return flatNode; } + public static int[] mapOrNull(List list, FlatBufferBuilder fbb){ + if(list == null) + return null; + int[] out = new int[list.size()]; + int i=0; + for(String s : list){ + out[i++] = fbb.createString(s); + } + return out; + } + public static DifferentialFunction cloneViaSerialize(SameDiff sd, DifferentialFunction df ){ Map nameToIdxMap = new HashMap<>(); int count = 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java index b8625afde..336ec37d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java @@ -131,12 +131,12 @@ public class GradCheckUtil { // in this case, gradients of x and y are all 0 too //Collect variables to get gradients for - we want placeholders AND variables - Set gradVarNames = new HashSet<>(); + Set varsNeedingGrads = new HashSet<>(); for(Variable v : sd.getVariables().values()){ if(v.getVariable().dataType().isFPType() && (v.getVariable().getVariableType() == VariableType.VARIABLE || v.getVariable().getVariableType() == VariableType.PLACEHOLDER)){ SDVariable g = v.getVariable().getGradient(); Preconditions.checkNotNull(g, "No gradient variable found for variable %s", v.getVariable()); - gradVarNames.add(g.getVarName()); + varsNeedingGrads.add(v.getName()); } } @@ -164,7 +164,7 @@ public class GradCheckUtil { } - sd.execBackwards(placeholderValues, new ArrayList<>(gradVarNames)); + Map gm = sd.calculateGradients(placeholderValues, varsNeedingGrads); //Remove listener, to reduce overhead sd.getListeners().remove(listenerIdx); @@ -183,11 +183,11 @@ public class GradCheckUtil { if(g == null){ throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\""); } - INDArray ga = g.getArr(); + INDArray ga = gm.get(v.getVarName()); if(ga == null){ throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName()); } - if(!Arrays.equals(v.getArr().shape(), g.getArr().shape())){ + if(!Arrays.equals(v.getArr().shape(), ga.shape())){ throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape())); @@ -408,18 +408,18 @@ public class GradCheckUtil { //Collect names of variables to get gradients for - i.e., the names of the GRADIENT variables for the specified activations sd.createGradFunction(); - Set gradVarNames = new HashSet<>(); + Set varsRequiringGrads = new HashSet<>(); for(String s : actGrads){ SDVariable grad = sd.getVariable(s).gradient(); Preconditions.checkState( grad != null,"Could not get gradient for activation \"%s\": gradient variable is null", s); - gradVarNames.add(grad.getVarName()); + varsRequiringGrads.add(s); } //Calculate analytical gradients - sd.execBackwards(config.getPlaceholderValues(), new ArrayList<>(gradVarNames)); + Map grads = sd.calculateGradients(config.getPlaceholderValues(), new ArrayList<>(varsRequiringGrads)); Map gradientsForAct = new HashMap<>(); for(String s : actGrads){ - INDArray arr = sd.getVariable(s).gradient().getArr(); + INDArray arr = grads.get(s); Preconditions.checkState(arr != null, "No activation gradient array for variable \"%s\"", s); gradientsForAct.put(s, arr.dup()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 74c1d868d..218004c67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -190,11 +190,13 @@ public class OpValidation { //Check forward pass: if (testCase.fwdTestFns() != null && testCase.fwdTestFns().size() > 0) { SameDiff sd = testCase.sameDiff(); + + //Collect variables we need outputs for... + Set reqVars = testCase.fwdTestFns().keySet(); + + Map out; try { - if(testCase.placeholderValues() != null){ - sd.resolveVariablesWith(testCase.placeholderValues()); - } - sd.exec(null, sd.outputs()); + out = sd.output(testCase.placeholderValues(), new ArrayList<>(reqVars)); } catch (Exception e) { throw new RuntimeException("Error during forward pass testing" + testCase.testNameErrMsg(), e); } @@ -206,7 +208,7 @@ public class OpValidation { e.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg()); } - INDArray actual = v.getArr(); + INDArray actual = out.get(v.getVarName()); if (actual == null) { throw new IllegalStateException("Null INDArray after forward pass for variable \"" + e.getKey() + "\""); } @@ -291,6 +293,12 @@ public class OpValidation { Preconditions.checkState((orig.getControlDeps() == null) == (des.getControlDeps() == null), "Control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps()); Preconditions.checkState(orig.getControlDeps() == null || orig.getControlDeps().equals(des.getControlDeps()), "Control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps()); + Preconditions.checkState((orig.getVarControlDeps() == null) == (des.getVarControlDeps() == null), "Op variable control dependencies differ: %s vs. %s", orig.getVarControlDeps(), des.getVarControlDeps()); + Preconditions.checkState(orig.getVarControlDeps() == null || orig.getVarControlDeps().equals(des.getVarControlDeps()), "Op variable control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps()); + + Preconditions.checkState((orig.getControlDepFor() == null) == (des.getControlDepFor() == null), "Op control dependencies for list differ: %s vs. %s", orig.getControlDepFor(), des.getControlDepFor()); + Preconditions.checkState(orig.getControlDepFor() == null || orig.getControlDepFor().equals(des.getControlDepFor()), "Op variable control dependencies differ: %s vs. %s", orig.getControlDepFor(), des.getControlDepFor()); + Preconditions.checkState(orig.getOp().getClass() == des.getOp().getClass(), "Classes differ: %s v. %s", orig.getOp().getClass(), des.getOp().getClass()); } @@ -317,6 +325,11 @@ public class OpValidation { Map varsBefore = original.getVariables(); Map varsAfter = deserialized.getVariables(); Preconditions.checkState(varsBefore.keySet().equals(varsAfter.keySet()), "Variable keysets do not match: %s vs %s", varsBefore.keySet(), varsAfter.keySet()); + +// System.out.println(original.summary()); +// System.out.println("\n\n\n\n"); +// System.out.println(deserialized.summary()); + for(String s : varsBefore.keySet()){ Variable vB = varsBefore.get(s); Variable vA = varsAfter.get(s); @@ -324,13 +337,15 @@ public class OpValidation { Preconditions.checkState(vB.getVariable().getVariableType() == vA.getVariable().getVariableType(), "Variable types do not match: %s - %s vs %s", s, vB.getVariable().getVariableType(), vA.getVariable().getVariableType()); - Preconditions.checkState((vB.getInputsForOp() == null) == (vA.getInputsForOp() == null), "Input to ops differ: %s vs. %s", vB.getInputsForOp(), vA.getInputsForOp()); - Preconditions.checkState(vB.getInputsForOp() == null || vB.getInputsForOp().equals(vA.getInputsForOp()), "Inputs differ: %s vs. %s", vB.getInputsForOp(), vA.getInputsForOp()); + equalConsideringNull(vB.getInputsForOp(), vA.getInputsForOp(), "%s - Input to ops differ: %s vs. %s", s, vB.getInputsForOp(), vA.getInputsForOp()); - Preconditions.checkState((vB.getOutputOfOp() == null && vA.getOutputOfOp() == null) || vB.getOutputOfOp().equals(vA.getOutputOfOp()), "Output of op differ: %s vs. %s", vB.getOutputOfOp(), vA.getOutputOfOp()); + Preconditions.checkState((vB.getOutputOfOp() == null && vA.getOutputOfOp() == null) || vB.getOutputOfOp().equals(vA.getOutputOfOp()), "%s - Output of op differ: %s vs. %s", s, vB.getOutputOfOp(), vA.getOutputOfOp()); - Preconditions.checkState((vB.getControlDeps() == null) == (vA.getControlDeps() == null), "Control dependencies differ: %s vs. %s", vB.getControlDeps(), vA.getControlDeps()); - Preconditions.checkState(vB.getControlDeps() == null || vB.getControlDeps().equals(vA.getControlDeps()), "Control dependencies differ: %s vs. %s", vB.getControlDeps(), vA.getControlDeps()); + equalConsideringNull(vB.getControlDeps(), vA.getControlDeps(), "%s - Control dependencies differ: %s vs. %s", s, vB.getControlDeps(), vA.getControlDeps()); + + equalConsideringNull(vB.getControlDepsForOp(), vA.getControlDepsForOp(), "%s - Control dependencies for ops differ: %s vs. %s", s, vB.getControlDepsForOp(), vA.getControlDepsForOp()); + + equalConsideringNull(vB.getControlDepsForVar(), vA.getControlDepsForVar(), "%s - Control dependencies for vars differ: %s vs. %s", s, vB.getControlDepsForVar(), vA.getControlDepsForVar()); } //Check loss variables: @@ -343,51 +358,62 @@ public class OpValidation { lossVarBefore, lossVarAfter); } + if(tc.fwdTestFns() != null && !tc.fwdTestFns().isEmpty()) { + //Finally: check execution/output + Map outOrig = original.outputAll(tc.placeholderValues()); + Map outDe = deserialized.outputAll(tc.placeholderValues()); + Preconditions.checkState(outOrig.keySet().equals(outDe.keySet()), "Keysets for execution after deserialization does not match key set for original model"); - //Finally: check execution/output - Map outOrig = original.outputAll(tc.placeholderValues()); - Map outDe = deserialized.outputAll(tc.placeholderValues()); - Preconditions.checkState(outOrig.keySet().equals(outDe.keySet()), "Keysets for execution after deserialization does not match key set for original model"); + for (String s : outOrig.keySet()) { + INDArray orig = outOrig.get(s); + INDArray deser = outDe.get(s); - for(String s : outOrig.keySet()){ - INDArray orig = outOrig.get(s); - INDArray deser = outDe.get(s); - - Function f = tc.fwdTestFns().get(s); - String err = null; - if(f != null){ - err = f.apply(deser); - } else { - if(!orig.equals(deser)){ - //Edge case: check for NaNs in original and deserialized... might be legitimate test (like replaceNaNs op) - long count = orig.dataType().isNumerical() ? Nd4j.getExecutioner().execAndReturn(new MatchCondition(orig, Conditions.isNan())).getFinalResult().longValue() : -1; - if(orig.dataType().isNumerical() && count > 0 && orig.equalShapes(deser)){ - long count2 = Nd4j.getExecutioner().execAndReturn(new MatchCondition(deser, Conditions.isNan())).getFinalResult().longValue(); - if(count != count2){ - err = "INDArray equality failed"; - } else { - //TODO is there a better way to do this? - NdIndexIterator iter = new NdIndexIterator(orig.shape()); - while(iter.hasNext()){ - long[] i = iter.next(); - double d1 = orig.getDouble(i); - double d2 = deser.getDouble(i); - if((Double.isNaN(d1) != Double.isNaN(d2)) || (Double.isInfinite(d1) != Double.isInfinite(d2)) || Math.abs(d1 - d2) > 1e-5 ){ - err = "INDArray equality failed"; - break; + Function f = tc.fwdTestFns().get(s); + String err = null; + if (f != null) { + err = f.apply(deser); + } else { + if (!orig.equals(deser)) { + //Edge case: check for NaNs in original and deserialized... might be legitimate test (like replaceNaNs op) + long count = orig.dataType().isNumerical() ? Nd4j.getExecutioner().execAndReturn(new MatchCondition(orig, Conditions.isNan())).getFinalResult().longValue() : -1; + if (orig.dataType().isNumerical() && count > 0 && orig.equalShapes(deser)) { + long count2 = Nd4j.getExecutioner().execAndReturn(new MatchCondition(deser, Conditions.isNan())).getFinalResult().longValue(); + if (count != count2) { + err = "INDArray equality failed"; + } else { + //TODO is there a better way to do this? + NdIndexIterator iter = new NdIndexIterator(orig.shape()); + while (iter.hasNext()) { + long[] i = iter.next(); + double d1 = orig.getDouble(i); + double d2 = deser.getDouble(i); + if ((Double.isNaN(d1) != Double.isNaN(d2)) || (Double.isInfinite(d1) != Double.isInfinite(d2)) || Math.abs(d1 - d2) > 1e-5) { + err = "INDArray equality failed"; + break; + } } } + } else { + err = "INDArray equality failed"; } - } else { - err = "INDArray equality failed"; } } - } - Preconditions.checkState(err == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", s, orig, deser, orig, deser, err); + Preconditions.checkState(err == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", s, orig, deser, orig, deser, err); + } } } + protected static void equalConsideringNull(List l1, List l2, String msg, Object... args){ + //Consider null and length 0 list to be equal (semantically they mean the same thing) + boolean empty1 = l1 == null || l1.isEmpty(); + boolean empty2 = l2 == null || l2.isEmpty(); + if(empty1 && empty2){ + return; + } + Preconditions.checkState(l1 == null || l1.equals(l2), msg, args); + } + /** * Validate the outputs of a single op * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java index f5ae0693d..7e7a50ab2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java @@ -25,6 +25,7 @@ public class NonInplaceValidationListener extends BaseListener { private static AtomicInteger failCounter = new AtomicInteger(); protected INDArray[] opInputs; + protected INDArray[] opInputsOrig; public NonInplaceValidationListener(){ useCounter.getAndIncrement(); @@ -42,14 +43,18 @@ public class NonInplaceValidationListener extends BaseListener { //No input op return; } else if(o.y() == null){ + opInputsOrig = new INDArray[]{o.x()}; opInputs = new INDArray[]{o.x().dup()}; } else { + opInputsOrig = new INDArray[]{o.x(), o.y()}; opInputs = new INDArray[]{o.x().dup(), o.y().dup()}; } } else if(op.getOp() instanceof DynamicCustomOp){ INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments(); opInputs = new INDArray[arr.length]; + opInputsOrig = new INDArray[arr.length]; for( int i=0; i= 0; i--) builder.addByte(data[i]); return builder.endVector(); } public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); } public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); } + public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); } + public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); } + public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); } + public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endFlatNode(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java index 76335c1ae..4845f7320 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java @@ -29,16 +29,28 @@ public final class FlatVariable extends Table { public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; } public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; } + public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; } + public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; } + public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; } public static int createFlatVariable(FlatBufferBuilder builder, - int idOffset, - int nameOffset, - byte dtype, - int shapeOffset, - int ndarrayOffset, - int device, - byte variabletype) { - builder.startObject(7); + int idOffset, + int nameOffset, + byte dtype, + int shapeOffset, + int ndarrayOffset, + int device, + byte variabletype, + int controlDepsOffset, + int controlDepForOpOffset, + int controlDepsForVarOffset) { + builder.startObject(10); + FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset); + FlatVariable.addControlDepForOp(builder, controlDepForOpOffset); + FlatVariable.addControlDeps(builder, controlDepsOffset); FlatVariable.addDevice(builder, device); FlatVariable.addNdarray(builder, ndarrayOffset); FlatVariable.addShape(builder, shapeOffset); @@ -49,7 +61,7 @@ public final class FlatVariable extends Table { return FlatVariable.endFlatVariable(builder); } - public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(7); } + public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); } public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); } @@ -59,6 +71,15 @@ public final class FlatVariable extends Table { public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); } public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); } public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); } + public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); } + public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); } + public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); } + public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endFlatVariable(FlatBufferBuilder builder) { int o = builder.endObject(); return o; @@ -67,3 +88,4 @@ public final class FlatVariable extends Table { public static void finishSizePrefixedFlatVariableBuffer(FlatBufferBuilder builder, int offset) { builder.finishSizePrefixed(offset); } } + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index 82bfdc843..05ac2495c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -25,11 +25,7 @@ import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser; import org.nd4j.imports.descriptors.onnx.OpDescriptor; import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser; import org.nd4j.linalg.api.ops.*; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.layers.convolution.*; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -370,6 +366,8 @@ public class DifferentialFunctionClassHolder { return Merge.class; case Switch.OP_NAME: return Switch.class; + case LoopCond.OP_NAME: + return LoopCond.class; case ExternalErrorsFunction.OP_NAME: return ExternalErrorsFunction.class; default: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 2fd2e6332..7bdea70b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -69,13 +69,9 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual.class, - org.nd4j.linalg.api.ops.impl.controlflow.If.class, - org.nd4j.linalg.api.ops.impl.controlflow.IfDerivative.class, org.nd4j.linalg.api.ops.impl.controlflow.Select.class, org.nd4j.linalg.api.ops.impl.controlflow.Where.class, org.nd4j.linalg.api.ops.impl.controlflow.WhereNumpy.class, - org.nd4j.linalg.api.ops.impl.controlflow.While.class, - org.nd4j.linalg.api.ops.impl.controlflow.WhileDerivative.class, org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter.class, org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit.class, org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java deleted file mode 100644 index 95f238973..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java +++ /dev/null @@ -1,413 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.imports.graphmapper; - -import org.nd4j.linalg.util.ArrayUtil; -import org.nd4j.shade.protobuf.Message; -import org.nd4j.shade.protobuf.TextFormat; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.io.IOUtils; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.VariableType; -import org.nd4j.autodiff.samediff.internal.SameDiffOp; -import org.nd4j.autodiff.samediff.internal.Variable; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.exception.ND4JIllegalStateException; - -import java.io.*; -import java.util.*; - -/** - * Base implementation for importing a graph - * - * @param the type of graph - * @param the type of node - * @param the attribute type - */ -@Slf4j -public abstract class BaseGraphMapper implements GraphMapper { - - - @Override - public Op.Type opTypeForNode(NODE_TYPE nodeDef) { - DifferentialFunction opWithTensorflowName = getMappedOp(getOpType(nodeDef)); - if (opWithTensorflowName == null) - throw new NoOpNameFoundException("No op found with name " + getOpType(nodeDef)); - Op.Type type = opWithTensorflowName.opType(); - return type; - - } - - - @Override - public void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappings) { - val mappings = propertyMappings.get(getOpType(node)); - if (mappings == null || mappings.isEmpty()) { - return; - } - - - for (val entry : mappings.entrySet()) { - mapProperty(entry.getKey(), on, node, graph, sameDiff, propertyMappings); - } - } - - - /** - * @param inputStream - * @return - */ - @Override - public SameDiff importGraph(InputStream inputStream) { - return importGraph(inputStream, Collections.>emptyMap(), null); - } - - @Override - public SameDiff importGraph(InputStream inputStream, Map> opImportOverrides, - OpImportFilter opFilter) { - GRAPH_TYPE def = readGraph(inputStream, opImportOverrides); - return importGraph(def, opImportOverrides, opFilter); - } - - protected GRAPH_TYPE readGraph(InputStream inputStream, Map> opImportOverrides) { - byte[] bytes = null; - GRAPH_TYPE def = null; - try { - bytes = IOUtils.toByteArray(inputStream); //Buffers internally - def = parseGraphFrom(bytes); - } catch (IOException e) { - try (BufferedInputStream bis2 = new BufferedInputStream(new ByteArrayInputStream(bytes)); BufferedReader reader = new BufferedReader(new InputStreamReader(bis2))) { - Message.Builder builder = getNewGraphBuilder(); - - StringBuilder str = new StringBuilder(); - String line = null; - while ((line = reader.readLine()) != null) { - str.append(line);//.append("\n"); - } - - TextFormat.getParser().merge(str.toString(), builder); - def = (GRAPH_TYPE) builder.build(); - } catch (Exception e2) { - e2.printStackTrace(); - } - } - - return def; - } - - /** - * @param graphFile - * @return - */ - @Override - public SameDiff importGraph(File graphFile) { - return importGraph(graphFile, Collections.>emptyMap(), null); - } - - @Override - public SameDiff importGraph(File graphFile, Map> opImportOverrides, - OpImportFilter opFilter) { - GRAPH_TYPE def = null; - try (FileInputStream fis = new FileInputStream(graphFile)) { - return importGraph(fis, opImportOverrides, opFilter); - } catch (Exception e) { - throw new ND4JIllegalStateException("Error encountered loading graph file: " + graphFile.getAbsolutePath(), e); - } - } - - @Override - public Map nameIndexForGraph(GRAPH_TYPE graph) { - List nodes = getNodeList(graph); - Map ret = new HashMap<>(); - for (NODE_TYPE node : nodes) { - ret.put(getName(node), node); - } - return ret; - } - - @Override - public Map nodesByName(GRAPH_TYPE graph) { - val nodeTypes = getNodeList(graph); - Map ret = new LinkedHashMap<>(); - for (int i = 0; i < nodeTypes.size(); i++) { - ret.put(getName(nodeTypes.get(i)), nodeTypes.get(i)); - } - return ret; - } - - /** - * This method converts given TF - * - * @param tfGraph - * @return - */ - @Override - public SameDiff importGraph(GRAPH_TYPE tfGraph) { - return importGraph(tfGraph, Collections.>emptyMap(), null); - } - - @Override - public SameDiff importGraph(GRAPH_TYPE tfGraph, Map> opImportOverrides, - OpImportFilter opFilter) { - - SameDiff diff = SameDiff.create(); - ImportState importState = new ImportState<>(); - importState.setSameDiff(diff); - importState.setGraph(tfGraph); - - Map variablesForGraph = variablesForGraph(tfGraph); - importState.setVariables(variablesForGraph); - - - //Add each of the variables first - before importing ops - Map stringNodes = new HashMap<>(); //Key: name of string variable. Value: if it's a constant - for (Map.Entry entry : variablesForGraph.entrySet()) { - if (shouldSkip((NODE_TYPE) entry.getValue())) { //TODO only works for TF - //Skip some nodes, for example reduction indices (a lot of ND4J/SameDiff ops use int[] etc, not an INDArray/Variable) - continue; - } - - //First: check if we're skipping the op entirely. If so: don't create the output variables for it. - NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF - String opType = getOpType(node); - String opName = getName(node); - if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, importState.getGraph() )){ - log.info("Skipping variables for op: {} (name: {})", opType, opName); - continue; - } - - //Similarly, if an OpImportOverride is defined, don't create the variables now, as these might be the wrong type - //For example, the OpImportOverride might replace the op with some placeholders - // If we simply created them now, we'd create the wrong type (Array not placeholder) - if(opImportOverrides != null && opImportOverrides.containsKey(opType)){ - log.info("Skipping variables for op due to presence of OpImportOverride: {} (name: {})", opType, opName); - continue; - } - - - DataType dt = dataTypeForTensor(entry.getValue(), 0); - INDArray arr = getNDArrayFromTensor(entry.getKey(), entry.getValue(), tfGraph); - long[] shape = hasShape((NODE_TYPE) entry.getValue()) ? getShape((NODE_TYPE) entry.getValue()) : null; //TODO only works for TF - - //Not all variables have datatypes available on import - we have to infer these at a later point - // so we'll leave datatypes as null and infer them once all variables/ops have been imported - if(dt == DataType.UNKNOWN) - dt = null; - - if (isPlaceHolder(entry.getValue())) { - diff.placeHolder(entry.getKey(), dt, shape); - } else if (isConstant(entry.getValue())) { - Preconditions.checkNotNull(arr, "Array is null for placeholder variable %s", entry.getKey()); - diff.constant(entry.getKey(), arr); - } else { - //Could be variable, or could be array type (i.e., output of op/"activations") - //TODO work out which! - - SDVariable v; - if(shape == null || ArrayUtil.contains(shape, 0)){ - //No shape, or 0 in shape -> probably not a variable... - v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null); - } else { - v = diff.var(entry.getKey(), dt, shape); - } - if (arr != null) - diff.associateArrayWithVariable(arr, v); - } - -// NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF - List controlDependencies = getControlDependencies(node); - if (controlDependencies != null) { - Variable v = diff.getVariables().get(entry.getKey()); - v.setControlDeps(controlDependencies); - } - } - - //Map ops - val tfNodesList = getNodeList(tfGraph); - for (NODE_TYPE node : tfNodesList) { - String opType = getOpType(node); - OpImportOverride importOverride = null; - if(opImportOverrides != null){ - importOverride = opImportOverrides.get(opType); - } - - if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, null)){ - String opName = getName(node); - log.info("Skipping op due to op filter: {}", opType, opName); - continue; - } - - if (!opsToIgnore().contains(opType) || isOpIgnoreException(node)) { - mapNodeType(node, importState, importOverride, opFilter); - } - } - - - /* - At this point, we have a few remaining things to do: - 1. Make sure all datatypes are set on all variables. TF doesn't have datatype info an all op outputs for some reason, so we have to infer in manually - 2. Make sure all op output variables have been created - 3. Make sure all SameDiffOp.outputsOfOp is set - 4. Make sure all Variable.outputOfOp is set - 5. Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps) - */ - - //Make sure Variable.outputOfOp is set - for(Variable v : diff.getVariables().values()){ - if(v.getVariable().isPlaceHolder() || v.getVariable().isConstant()) - continue; - - //Expect variable names of output variables to be: opName, opName:1, opName:2, etc - String n = v.getName(); - String opName = n; - if(v.getName().matches(".*:\\d+")){ - //i.e., "something:2" - int idx = n.lastIndexOf(':'); - opName = n.substring(0,idx); - } - - if(diff.getOps().containsKey(opName)) { - //Variable is the output of an op - v.setOutputOfOp(opName); - - //Also double check variable type... - if(v.getVariable().getVariableType() != VariableType.ARRAY) - v.getVariable().setVariableType(VariableType.ARRAY); - } - } - - //Initialize any missing output variables - for (SameDiffOp op : diff.getOps().values()) { - DifferentialFunction df = op.getOp(); - initOutputVariables(diff, df); - } - - //Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps) - //i.e., if control dependency x -> y exists, then: - // (a) x.controlDepsForVar should contain "y" - // (b) y.controlDeps should contain "x" - //Need to do this before output datatype calculation, as these control dep info is used in sessions - for(Map.Entry e : diff.getVariables().entrySet()){ - Variable v = e.getValue(); - if(v.getControlDeps() != null){ - for(String s : v.getControlDeps()){ - Variable v2 = diff.getVariables().get(s); - if(v2.getControlDepsForVar() == null) - v2.setControlDepsForVar(new ArrayList()); - if(!v2.getControlDepsForVar().contains(e.getKey())){ - //Control dep v2 -> v exists, so put v.name into v2.controlDepsForVar - v2.getControlDepsForVar().add(e.getKey()); - } - } - } - } - - //Same thing for op control dependencies... - for(Map.Entry e : diff.getOps().entrySet()){ - SameDiffOp op = e.getValue(); - if(op.getControlDeps() != null){ - for(String s : op.getControlDeps()){ - //Control dependency varS -> op exists - Variable v = diff.getVariables().get(s); - if(v.getControlDepsForOp() == null) - v.setControlDepsForOp(new ArrayList()); - if(!v.getControlDepsForOp().contains(e.getKey())) - v.getControlDepsForOp().add(e.getKey()); - } - } - } - - - //Infer variable datatypes to ensure all variables have datatypes... - boolean anyUnknown = false; - for(SDVariable v : diff.variables()){ - if(v.dataType() == null) - anyUnknown = true; - } - if(anyUnknown){ - Map dataTypes = diff.calculateOutputDataTypes(); - for(SDVariable v : diff.variables()){ - if(v.dataType() == null){ - v.setDataType(dataTypes.get(v.getVarName())); - } - } - } - - //Validate the graph structure - validateGraphStructure(diff); - - return diff; - } - - protected void initOutputVariables(SameDiff sd, DifferentialFunction df) { - String[] outNames = sd.getOutputsForOp(df); - SDVariable[] outVars; - if (outNames == null) { - outVars = sd.generateOutputVariableForOp(df, df.getOwnName() != null ? df.getOwnName() : df.opName(), true); - outNames = new String[outVars.length]; - for (int i = 0; i < outVars.length; i++) { - outNames[i] = outVars[i].getVarName(); - } - sd.getOps().get(df.getOwnName()).setOutputsOfOp(Arrays.asList(outNames)); - } - - for (String s : outNames) { - sd.getVariables().get(s).setOutputOfOp(df.getOwnName()); - } - } - - - @Override - public boolean validTensorDataType(TENSOR_TYPE tensorType) { - return dataTypeForTensor(tensorType, 0) != DataType.UNKNOWN; - } - - public void validateGraphStructure(SameDiff sameDiff) { - //First: Check placeholders. When SDVariables are added with null shapes, these can be interpreted as a placeholder - // but null shapes might simply mean shape isn't available during import right when the variable is added - //Idea here: if a "placeholder" is the output of any function, it's not really a placeholder - for (SDVariable v : sameDiff.variables()) { - String name = v.getVarName(); - if (sameDiff.isPlaceHolder(name)) { - String varOutputOf = sameDiff.getVariables().get(name).getOutputOfOp(); - } - } - - //Second: check that all op inputs actually exist in the graph - for (SameDiffOp op : sameDiff.getOps().values()) { - List inputs = op.getInputsToOp(); - if (inputs == null) - continue; - - for (String s : inputs) { - if (sameDiff.getVariable(s) == null) { - throw new IllegalStateException("Import validation failed: op \"" + op.getName() + "\" of type " + op.getOp().getClass().getSimpleName() - + " has input \"" + s + "\" that does not have a corresponding variable in the graph"); - } - } - } - } - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java deleted file mode 100644 index 2d89a2b07..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java +++ /dev/null @@ -1,429 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.imports.graphmapper; - -import org.nd4j.shade.protobuf.Message; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.Op; - -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.util.List; -import java.util.Map; -import java.util.Set; - -/** - * Map graph proto types to - * - * {@link SameDiff} instances - * @param the proto type for the graph - * @param the proto type for the node - * @param the proto type for the attribute - * @param the proto type for the tensor - *@author Adam Gibson - */ -public interface GraphMapper { - - /** - * Import a graph as SameDiff from the given file - * @param graphFile Input stream pointing to graph file to import - * @return Imported graph - */ - SameDiff importGraph(InputStream graphFile); - - SameDiff importGraph(InputStream graphFile, Map> opImportOverrides, - OpImportFilter opFilter); - - /** - * Import a graph as SameDiff from the given file - * @param graphFile Graph file to import - * @return Imported graph - * @see #importGraph(File, Map) - */ - SameDiff importGraph(File graphFile); - - /** - * Import a graph as SameDiff from the given file, with optional op import overrides.
- * The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops - * that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions. - * - * @param graphFile Graph file to import - * @param opImportOverrides May be null. If non-null: used to import the specified operations. Key is the name of the - * operation to import, value is the object used to import it - * @return Imported graph - */ - SameDiff importGraph(File graphFile, Map> opImportOverrides, - OpImportFilter opFilter); - - /** - * This method converts given graph type (in its native format) to SameDiff - * @param graph Graph to import - * @return Imported graph - */ - SameDiff importGraph(GRAPH_TYPE graph); - - /** - * This method converts given graph type (in its native format) to SameDiff
- * The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops - * that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions. - * @param graph Graph to import - * @return Imported graph - */ - SameDiff importGraph(GRAPH_TYPE graph, Map> opImportOverrides, - OpImportFilter opFilter); - - - /** - * Returns true if this node is a special case - * (maybe because of name or other scenarios) - * that should override {@link #opsToIgnore()} - * in certain circumstances - * @param node the node to check - * @return true if this node is an exception false otherwise - */ - boolean isOpIgnoreException(NODE_TYPE node); - - /** - * Get the nodes sorted by n ame - * from a given graph - * @param graph the graph to get the nodes for - * @return the map of the nodes by name - * for a given graph - */ - Map nodesByName(GRAPH_TYPE graph); - - /** - * Get the target mapping key (usually based on the node name) - * for the given function - * @param function the function - * @param node the node to derive the target mapping from - * @return - */ - String getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node); - - - /** - * - * @param on - * @param node - * @param graph - * @param sameDiff - * @param propertyMappings - */ - void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappings); - - - /** - * - * @param name - * @param on - * @param node - * @param graph - * @param sameDiff - * @param propertyMappingsForFunction - */ - void mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappingsForFunction); - - /** - * Get the node from the graph - * @param graph the graph to get the node from - * @param name the name of the node to get from the graph - * @return - */ - NODE_TYPE getNodeWithNameFromGraph(GRAPH_TYPE graph,String name); - - /** - * Returns true if the given node is a place holder - * @param node the node to check - * @return true if the node is a place holder or not - */ - boolean isPlaceHolderNode(TENSOR_TYPE node); - - /** - * Get the list of control dependencies for the current node (or null if none exist) - * - * @param node Node to get the control dependencies (if any) for - * @return - */ - List getControlDependencies(NODE_TYPE node); - - /** - * Dump a binary proto file representation as a - * plain string in to the target text file - * @param inputFile - * @param outputFile - */ - void dumpBinaryProtoAsText(File inputFile,File outputFile); - - - /** - * Dump a binary proto file representation as a - * plain string in to the target text file - * @param inputFile - * @param outputFile - */ - void dumpBinaryProtoAsText(InputStream inputFile,File outputFile); - - - /** - * Get the mapped op name - * for a given op - * relative to the type of node being mapped. - * The input name should be based on a tensorflow - * type or onnx type, not the nd4j name - * @param name the tensorflow or onnx name - * @return the function based on the values in - * {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder} - */ - DifferentialFunction getMappedOp(String name); - - - /** - * Get the variables for the given graph - * @param graphType the graph to get the variables for - * @return a map of variable name to tensor - */ - Map variablesForGraph(GRAPH_TYPE graphType); - - /** - * - * @param name - * @param node - * @return - */ - String translateToSameDiffName(String name, NODE_TYPE node); - - - /** - * - * @param graph - * @return - */ - Map nameIndexForGraph(GRAPH_TYPE graph); - - /** - * Returns an op type for the given input node - * @param nodeType the node to use - * @return the optype for the given node - */ - Op.Type opTypeForNode(NODE_TYPE nodeType); - - /** - * Returns a graph builder for initial definition and parsing. - * @return - */ - Message.Builder getNewGraphBuilder(); - - /** - * Parse a graph from an input stream - * @param inputStream the input stream to load from - * @return - */ - GRAPH_TYPE parseGraphFrom(byte[] inputStream) throws IOException; - - /** - * Parse a graph from an input stream - * @param inputStream the input stream to load from - * @return - */ - GRAPH_TYPE parseGraphFrom(InputStream inputStream) throws IOException; - - - /** - * Map a node in to the import state covering the {@link SameDiff} instance - * @param tfNode the node to map - * @param importState the current import state - * @param opFilter Optional filter for skipping operations - */ - void mapNodeType(NODE_TYPE tfNode, ImportState importState, - OpImportOverride opImportOverride, - OpImportFilter opFilter); - - - /** - * - * @param tensorType - * @param outputNum - * @return - */ - DataType dataTypeForTensor(TENSOR_TYPE tensorType, int outputNum); - - boolean isStringType(TENSOR_TYPE tensor); - - /** - * - * @param nodeType - * @param key - * @return - */ - String getAttrValueFromNode(NODE_TYPE nodeType,String key); - - - /** - * - * @param attrType - * @return - */ - long[] getShapeFromAttribute(ATTR_TYPE attrType); - - /** - * Returns true if the given node is a place holder type - * (think a yet to be determined shape)_ - * @param nodeType - * @return - */ - boolean isPlaceHolder(TENSOR_TYPE nodeType); - - - /** - * Returns true if the given node is a constant - * @param nodeType - * @return - */ - boolean isConstant(TENSOR_TYPE nodeType); - - /** - * - * - * @param tensorName - * @param tensorType - * @param graph - * @return - */ - INDArray getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph); - - - /** - * Get the shape for the given tensor type - * @param tensorType - * @return - */ - long[] getShapeFromTensor(TENSOR_TYPE tensorType); - - - /** - * Ops to ignore for mapping - * @return - */ - Set opsToIgnore(); - - /** - * Get the input node for the given node - * @param node the node - * @param index hte index - * @return - */ - String getInputFromNode(NODE_TYPE node, int index); - - /** - * Get the number of inputs for a node. - * @param nodeType the node to get the number of inputs for - * @return - */ - int numInputsFor(NODE_TYPE nodeType); - - /** - * Whether the data type for the tensor is valid - * for creating an {@link INDArray} - * @param tensorType the tensor proto to test - * @return - */ - boolean validTensorDataType(TENSOR_TYPE tensorType); - - - /** - * Get the shape of the attribute value - * @param attr the attribute value - * @return the shape of the attribute if any or null - */ - long[] getShapeFromAttr(ATTR_TYPE attr); - - /** - * Get the attribute - * map for given node - * @param nodeType the node - * @return the attribute map for the attribute - */ - Map getAttrMap(NODE_TYPE nodeType); - - /** - * Get the name of the node - * @param nodeType the node - * to get the name for - * @return - */ - String getName(NODE_TYPE nodeType); - - /** - * - * @param nodeType - * @return - */ - boolean alreadySeen(NODE_TYPE nodeType); - - /** - * - * @param nodeType - * @return - */ - boolean isVariableNode(NODE_TYPE nodeType); - - /** - * - * - * @param opType - * @return - */ - boolean shouldSkip(NODE_TYPE opType); - - /** - * - * @param nodeType - * @return - */ - boolean hasShape(NODE_TYPE nodeType); - - /** - * - * @param nodeType - * @return - */ - long[] getShape(NODE_TYPE nodeType); - - /** - * - * @param nodeType - * @param graph - * @return - */ - INDArray getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph); - - - String getOpType(NODE_TYPE nodeType); - - /** - * - * @param graphType - * @return - */ - List getNodeList(GRAPH_TYPE graphType); -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/ImportState.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/ImportState.java deleted file mode 100644 index 1246f66fa..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/ImportState.java +++ /dev/null @@ -1,31 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.imports.graphmapper; - -import lombok.Data; -import org.nd4j.autodiff.samediff.SameDiff; - -import java.util.Map; - -@Data -public class ImportState { - private SameDiff sameDiff; - private GRAPH_TYPE graph; - private Map variables; - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java deleted file mode 100644 index 7a651fb88..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java +++ /dev/null @@ -1,652 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.imports.graphmapper.onnx; - -import org.nd4j.shade.protobuf.ByteString; -import org.nd4j.shade.protobuf.Message; -import org.nd4j.shade.guava.primitives.Floats; -import org.nd4j.shade.guava.primitives.Ints; -import org.nd4j.shade.guava.primitives.Longs; -import lombok.val; -import onnx.Onnx; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.converters.DifferentialFunctionClassHolder; -import org.nd4j.imports.descriptors.properties.AttributeAdapter; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.BaseGraphMapper; -import org.nd4j.imports.graphmapper.ImportState; -import org.nd4j.imports.graphmapper.OpImportFilter; -import org.nd4j.imports.graphmapper.OpImportOverride; -import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; - -import java.io.*; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.*; - -/** - * A mapper for onnx graphs to - * {@link org.nd4j.autodiff.samediff.SameDiff} instances. - * - * @author Adam Gibson - */ -public class OnnxGraphMapper extends BaseGraphMapper { - private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper(); - - - public static OnnxGraphMapper getInstance() { - return INSTANCE; - } - - - @Override - public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) { - try { - Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile); - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); - for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) { - bufferedWriter.write(node.toString() + "\n"); - } - - bufferedWriter.flush(); - bufferedWriter.close(); - - } catch (IOException e) { - e.printStackTrace(); - } - } - - - - /** - * Init a function's attributes - * @param mappedTfName the onnx name to pick (sometimes ops have multiple names - * @param on the function to map - * @param attributesForNode the attributes for the node - * @param node - * @param graph - */ - public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) { - val properties = on.mappingsForFunction(); - val tfProperties = properties.get(mappedTfName); - val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); - val attributeAdapters = on.attributeAdaptersForFunction(); - for(val entry : tfProperties.entrySet()) { - val tfAttrName = entry.getValue().getTfAttrName(); - val currentField = fields.get(entry.getKey()); - - AttributeAdapter adapter = null; - if(tfAttrName != null) { - if(currentField == null) { - continue; - } - if(attributeAdapters != null && !attributeAdapters.isEmpty()) { - val mappers = attributeAdapters.get(on.tensorflowName()); - val adapterFor = mappers.get(entry.getKey()); - adapter = adapterFor; - } - - - if(attributesForNode.containsKey(tfAttrName)) { - val attr = attributesForNode.get(tfAttrName); - switch (attr.getType()) { - case STRING: - val setString = attr.getS().toStringUtf8(); - if(adapter != null) { - adapter.mapAttributeFor(setString,currentField,on); - } - else - on.setValueFor(currentField,setString); - break; - case INT: - val setInt = (int) attr.getI(); - if(adapter != null) { - adapter.mapAttributeFor(setInt,currentField,on); - } - else - on.setValueFor(currentField,setInt); - break; - case INTS: - val setList = attr.getIntsList(); - if(!setList.isEmpty()) { - val intList = Ints.toArray(setList); - if(adapter != null) { - adapter.mapAttributeFor(intList,currentField,on); - } - else - on.setValueFor(currentField,intList); - } - break; - case FLOATS: - val floatsList = attr.getFloatsList(); - if(!floatsList.isEmpty()) { - val floats = Floats.toArray(floatsList); - if(adapter != null) { - adapter.mapAttributeFor(floats,currentField,on); - } - - else - on.setValueFor(currentField,floats); - break; - } - break; - case TENSOR: - val tensorToGet = mapTensorProto(attr.getT()); - if(adapter != null) { - adapter.mapAttributeFor(tensorToGet,currentField,on); - } - else - on.setValueFor(currentField,tensorToGet); - break; - - } - } - } - - - } - } - - @Override - public boolean isOpIgnoreException(Onnx.NodeProto node) { - return false; - } - - @Override - public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) { - return function.opName(); - } - - - @Override - public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map> propertyMappingsForFunction) { - val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node)); - val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); - /** - * Map ints and the like. Need to figure out how attribute mapping should work. - * - * - */ - - val propsForFunction = on.propertiesForFunction(); - - if(mapping.getTfAttrName() == null) { - int tfMappingIdx = mapping.getTfInputPosition(); - if(tfMappingIdx < 0) - tfMappingIdx += node.getInputCount(); - - val input = node.getInput(tfMappingIdx); - val inputNode = getInstance().getNodeWithNameFromGraph(graph,input); - INDArray arr = sameDiff.getArrForVarName(input); - val field = fields.get(mapping.getPropertyNames()[0]); - val type = field.getType(); - if(type.equals(int[].class)) { - try { - field.set(arr.data().asInt(),on); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } - } - else if(type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) { - try { - field.set(arr.getInt(0),on); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } - } - else if(type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) { - try { - field.set(arr.getDouble(0),on); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } - } - - - - /** - * Figure out whether it's an int array - * or a double array, or maybe a scalar. - */ - - } - else { - val tfMappingAttrName = mapping.getOnnxAttrName(); - val attr = getAttrMap(node).get(tfMappingAttrName); - val type = attr.getType(); - val field = fields.get(mapping.getPropertyNames()[0]); - - Object valueToSet = null; - switch(type) { - case INT: - valueToSet = attr.getI(); - break; - case FLOAT: - valueToSet = attr.getF(); - break; - case STRING: - valueToSet = attr.getF(); - break; - - } - - try { - field.set(valueToSet,on); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } - - } - } - - - @Override - public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) { - for(int i = 0; i < graph.getNodeCount(); i++) { - val node = graph.getNode(i); - if(node.getName().equals(name)) - return node; - } - - return null; - } - - @Override - public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) { - return false; - } - - @Override - public List getControlDependencies(Onnx.NodeProto node) { - throw new UnsupportedOperationException("Not yet implemented"); - } - - @Override - public void dumpBinaryProtoAsText(File inputFile, File outputFile) { - try { - Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile))); - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); - for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) { - bufferedWriter.write(node.toString()); - } - - bufferedWriter.flush(); - bufferedWriter.close(); - - } catch (IOException e) { - e.printStackTrace(); - } - } - - - - - /** - * - * @param name the tensorflow or onnx name - * @return - */ - @Override - public DifferentialFunction getMappedOp(String name) { - return DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(name); - } - - - - @Override - public Map variablesForGraph(Onnx.GraphProto graphProto) { - /** - * Need to figure out why - * gpu_0/conv1_1 isn't present in VGG - */ - Map ret = new HashMap<>(); - for(int i = 0; i < graphProto.getInputCount(); i++) { - ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType()); - } - - for(int i = 0; i < graphProto.getOutputCount(); i++) { - ret.put(graphProto.getOutput(i).getName(),graphProto.getOutput(i).getType().getTensorType()); - } - - for(int i = 0; i < graphProto.getNodeCount(); i++) { - val node = graphProto.getNode(i); - val name = node.getName().isEmpty() ? String.valueOf(i) : node.getName(); - //add -1 as place holder value representing the shape needs to be filled in - if(!ret.containsKey(name)) { - addDummyTensor(name,ret); - } - - for(int j = 0; j < node.getInputCount(); j++) { - if(!ret.containsKey(node.getInput(j))) { - addDummyTensor(node.getInput(j),ret); - } - } - - - for(int j = 0; j < node.getOutputCount(); j++) { - if(!ret.containsKey(node.getOutput(j))) { - addDummyTensor(node.getOutput(j),ret); - } - } - } - - return ret; - } - - @Override - public String translateToSameDiffName(String name, Onnx.NodeProto node) { - return null; - } - - - protected void addDummyTensor(String name, Map to) { - Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension. - newBuilder() - .setDimValue(-1) - .build(); - Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder() - .setShape( - Onnx.TensorShapeProto.newBuilder() - .addDim(dim) - .addDim(dim).build()) - .build(); - to.put(name,typeProto); - } - - @Override - public Message.Builder getNewGraphBuilder() { - return Onnx.GraphProto.newBuilder(); - } - - @Override - public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException { - return Onnx.ModelProto.parseFrom(inputStream).getGraph(); - } - - @Override - public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException { - return Onnx.ModelProto.parseFrom(inputStream).getGraph(); - } - - @Override - public void mapNodeType(Onnx.NodeProto tfNode, ImportState importState, - OpImportOverride opImportOverride, - OpImportFilter opFilter) { - val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType()); - if(differentialFunction == null) { - throw new NoOpNameFoundException("No op name found " + tfNode.getOpType()); - } - - val diff = importState.getSameDiff(); - val idx = importState.getGraph().getNodeList().indexOf(tfNode); - val name = !tfNode.getName().isEmpty() ? tfNode.getName() : String.valueOf(idx); - try { - val newInstance = differentialFunction.getClass().newInstance(); - val args = new SDVariable[tfNode.getInputCount()]; - - newInstance.setSameDiff(importState.getSameDiff()); - - newInstance.initFromOnnx(tfNode,diff,getAttrMap(tfNode),importState.getGraph()); - importState.getSameDiff().putOpForId(newInstance.getOwnName(),newInstance); - //ensure we can track node name to function instance later. - diff.setBaseNameForFunctionInstanceId(tfNode.getName(),newInstance); - //diff.addVarNameForImport(tfNode.getName()); - } - catch (Exception e) { - e.printStackTrace(); - } - - - - } - - - - @Override - public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) { - return nd4jTypeFromOnnxType(tensorProto.getElemType()); - } - - @Override - public boolean isStringType(Onnx.TypeProto.Tensor tensor) { - return tensor.getElemType() == Onnx.TensorProto.DataType.STRING; - } - - - /** - * Convert an onnx type to the proper nd4j type - * @param dataType the data type to convert - * @return the nd4j type for the onnx type - */ - public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) { - switch (dataType) { - case DOUBLE: return DataType.DOUBLE; - case FLOAT: return DataType.FLOAT; - case FLOAT16: return DataType.HALF; - case INT32: - case INT64: return DataType.INT; - default: return DataType.UNKNOWN; - } - } - - @Override - public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) { - for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) { - if(attributeProto.getName().equals(key)) { - return attributeProto.getS().toString(); - } - } - - throw new ND4JIllegalStateException("No key found for " + key); - } - - @Override - public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) { - return Longs.toArray(attributeProto.getT().getDimsList()); - } - - @Override - public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) { - return false; - } - - @Override - public boolean isConstant(Onnx.TypeProto.Tensor nodeType) { - return false; - } - - - @Override - public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) { - DataType type = dataTypeForTensor(tensorProto, 0); - if(!tensorProto.isInitialized()) { - throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized"); - } - - Onnx.TensorProto tensor = null; - for(int i = 0; i < graph.getInitializerCount(); i++) { - val initializer = graph.getInitializer(i); - if(initializer.getName().equals(tensorName)) { - tensor = initializer; - break; - } - } - - if(tensor == null) - return null; - - ByteString bytes = tensor.getRawData(); - ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()); - ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder()); - directAlloc.put(byteBuffer); - directAlloc.rewind(); - long[] shape = getShapeFromTensor(tensorProto); - DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape)); - INDArray arr = Nd4j.create(buffer).reshape(shape); - return arr; - } - - public INDArray mapTensorProto(Onnx.TensorProto tensor) { - if(tensor == null) - return null; - - - DataType type = nd4jTypeFromOnnxType(tensor.getDataType()); - - ByteString bytes = tensor.getRawData(); - ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()); - ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder()); - directAlloc.put(byteBuffer); - directAlloc.rewind(); - long[] shape = getShapeFromTensor(tensor); - DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape)); - INDArray arr = Nd4j.create(buffer).reshape(shape); - return arr; - } - - @Override - public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) { - val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())]; - int dimCount = tensorProto.getShape().getDimCount(); - if(dimCount >= 2) - for(int i = 0; i < ret.length; i++) { - ret[i] = (int) tensorProto.getShape().getDim(i).getDimValue(); - } - else { - ret[0] = 1; - for(int i = 1; i < ret.length; i++) { - ret[i] = (int) tensorProto.getShape().getDim(i - 1).getDimValue(); - } - } - - - return ret; - } - - - /** - * Get the shape from a tensor proto. - * Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)} - * @param tensorProto the tensor to get the shape from - * @return - */ - public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) { - val ret = new long[Math.max(2,tensorProto.getDimsCount())]; - int dimCount = tensorProto.getDimsCount(); - if(dimCount >= 2) - for(int i = 0; i < ret.length; i++) { - ret[i] = (int) tensorProto.getDims(i); - } - else { - ret[0] = 1; - for(int i = 1; i < ret.length; i++) { - ret[i] = (int) tensorProto.getDims(i - 1); - } - } - - - return ret; - } - - @Override - public Set opsToIgnore() { - return Collections.emptySet(); - } - - - @Override - public String getInputFromNode(Onnx.NodeProto node, int index) { - return node.getInput(index); - } - - @Override - public int numInputsFor(Onnx.NodeProto nodeProto) { - return nodeProto.getInputCount(); - } - - - @Override - public long[] getShapeFromAttr(Onnx.AttributeProto attr) { - return Longs.toArray(attr.getT().getDimsList()); - } - - @Override - public Map getAttrMap(Onnx.NodeProto nodeProto) { - Map proto = new HashMap<>(); - for(int i = 0; i < nodeProto.getAttributeCount(); i++) { - Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i); - proto.put(attributeProto.getName(),attributeProto); - } - return proto; - } - - @Override - public String getName(Onnx.NodeProto nodeProto) { - return nodeProto.getName(); - } - - @Override - public boolean alreadySeen(Onnx.NodeProto nodeProto) { - return false; - } - - @Override - public boolean isVariableNode(Onnx.NodeProto nodeProto) { - return nodeProto.getOpType().contains("Var"); - } - - @Override - public boolean shouldSkip(Onnx.NodeProto opType) { - return false; - } - - @Override - public boolean hasShape(Onnx.NodeProto nodeProto) { - return false; - } - - @Override - public long[] getShape(Onnx.NodeProto nodeProto) { - return null; - } - - @Override - public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) { - - return null; - } - - @Override - public String getOpType(Onnx.NodeProto nodeProto) { - return nodeProto.getOpType(); - } - - @Override - public List getNodeList(Onnx.GraphProto graphProto) { - return graphProto.getNodeList(); - } - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java index 3ad3267c2..be90bb545 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java @@ -16,11 +16,10 @@ package org.nd4j.imports.graphmapper.tf; -import org.nd4j.shade.protobuf.Message; -import org.nd4j.shade.guava.primitives.Floats; -import org.nd4j.shade.guava.primitives.Ints; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.apache.commons.io.IOUtils; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,661 +30,620 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser; -import org.nd4j.imports.graphmapper.BaseGraphMapper; -import org.nd4j.imports.graphmapper.ImportState; -import org.nd4j.imports.graphmapper.OpImportFilter; -import org.nd4j.imports.graphmapper.OpImportOverride; import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper; import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers; +import org.nd4j.imports.tensorflow.TFImportOverride; +import org.nd4j.imports.tensorflow.TFOpImportFilter; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState; -import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; +import org.nd4j.shade.guava.primitives.Floats; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.protobuf.InvalidProtocolBufferException; +import org.nd4j.shade.protobuf.Message; +import org.nd4j.shade.protobuf.TextFormat; import org.tensorflow.framework.*; -import org.tensorflow.framework.DataType; import java.io.*; -import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.*; /** - * Map tensorflow graph protos - * to the intermediate representation - * for samediff. + * Import a TensorFlow frozen graph in ProtoBuf (.pb) format, to SameDiff * - * @author Adam Gibson + * @author Alex Black */ @Slf4j -public class TFGraphMapper extends BaseGraphMapper { - private Set seenNodes = new LinkedHashSet<>(); - public final static String VALUE_ATTR_KEY = "value"; - public final static String SHAPE_KEY = "shape"; - private static TFGraphMapper MAPPER_INSTANCE = new TFGraphMapper(); - private Set graphMapper = new HashSet(){{ - //While and If - //While -> Enter - /** - * Need to work on coping with variables - * that are marked as "shouldSkip" - * - * Possibly consider replacing should skip - * with a special handler interface. Something like - * - * public interface ImportOpHandler - */ - add("LoopCond"); - /** - * We should skip this for the sake of while..but not if. - * Need to be a bit more flexible here. - */ - add("Merge"); - add("Exit"); - add("NextIteration"); - add("NoOp"); - add("Switch"); - }}; - //singleton - private TFGraphMapper() {} +public class TFGraphMapper { /** - * Singleton. Get the needed instance. - * @return + * Import a frozen TensorFlow protobuf (.pb) file from the specified file + * + * @param f Frozen TensorFlow model pb file to import + * @return Imported graph */ - public static TFGraphMapper getInstance() { - return MAPPER_INSTANCE; + public static SameDiff importGraph(@NonNull File f) { + return importGraph(f, null, null); } - @Override - public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) { - try { - GraphDef graphDef = GraphDef.parseFrom(inputFile); - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); - for(NodeDef node : graphDef.getNodeList()) { - bufferedWriter.write(node.toString()); - } - - bufferedWriter.flush(); - bufferedWriter.close(); - + /** + * Import a frozen TensorFlow protobuf (.pb) file from the specified file, with optional overrides + * + * @param f Frozen TensorFlow model pb file to import + * @param importOverride Optional import override for specific ops, keyed by op name + * @param opFilter Optional filter - ops to exclude/ignore + * @return Imported graph + */ + public static SameDiff importGraph(@NonNull File f, Map importOverride, TFOpImportFilter opFilter) { + Preconditions.checkState(f.exists(), "File does not exist: %s", f); + try (InputStream is = new BufferedInputStream(new FileInputStream(f))) { + return importGraph(is, importOverride, opFilter); } catch (IOException e) { - e.printStackTrace(); + throw new RuntimeException(e); } } - @Override - public boolean isOpIgnoreException(NodeDef node) { - //if statements should not be ignored -/* - if(node.getOp().equals("Merge")) { - boolean ret = false; - for(int i = 0; i < node.getInputCount(); i++) { - //while loop - ret = ret || !node.getInput(i).endsWith("/Enter") || !node.getInput(i).endsWith("/NextIteration"); + /** + * Import a frozen TensorFlow protobuf (.pb) file, via an input stream + * + * @param is Stream for a frozen TensorFlow model pb file to import + * @return Imported graph + */ + public static SameDiff importGraph(@NonNull InputStream is) { + return importGraph(is, null, null); + } + /** + * Import a frozen TensorFlow protobuf file in text format (.pb.txt) file via an input stream, with optional overrides + * + * @param is Stream for a frozen TensorFlow model pb file to import + * @param importOverride Optional import override for specific ops, keyed by op name + * @param opFilter Optional filter - ops to exclude/ignore + * @return Imported graph + */ + public static SameDiff importGraphTxt(@NonNull InputStream is, Map importOverride, TFOpImportFilter opFilter) { + GraphDef tfGraph; + try { + Message.Builder builder = GraphDef.newBuilder(); + String content = IOUtils.toString(is, StandardCharsets.UTF_8); + TextFormat.getParser().merge(content, builder); + tfGraph = (GraphDef) builder.build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return importGraph(tfGraph, importOverride, opFilter); + } + + /** + * Import a frozen TensorFlow protobuf (.pb) file via an input stream, with optional overrides + * + * @param is Stream for a frozen TensorFlow model pb file to import + * @param importOverride Optional import override for specific ops, keyed by op name + * @param opFilter Optional filter - ops to exclude/ignore + * @return Imported graph + */ + public static SameDiff importGraph(@NonNull InputStream is, Map importOverride, TFOpImportFilter opFilter) { + GraphDef tfGraph; + try { + tfGraph = GraphDef.parseFrom(is); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return importGraph(tfGraph, importOverride, opFilter); + } + + /** + * Import a TensorFlow model from a GraphDef + * + * @param tfGraph TensorFlow model GraphDef + * @return Imported model + */ + public static SameDiff importGraph(@NonNull GraphDef tfGraph) { + return importGraph(tfGraph, null, null); + } + + /** + * Import a TensorFlow model from a GraphDef, with optional import overrides + * + * @param tfGraph TensorFlow model GraphDef + * @param importOverride Optional import override for specific ops, keyed by op name + * @param opFilter Optional filter - ops to exclude/ignore + * @return Imported model + */ + public static SameDiff importGraph(@NonNull GraphDef tfGraph, Map importOverride, TFOpImportFilter opFilter) { + + /* + First, build an in-memory representation of the graph that allows us to build the graph incrementally + If we can build the graph incrementally, we can make sure that the added variables are set up with the correct + datatype and (once implemented) greedy shape inference + */ + Set availableToAddSet = new HashSet<>(); //TODO maybe unnecessary? + Queue availableToAdd = new LinkedList<>(); + + Map remainingNodes = new HashMap<>(); //All other nodes, not in availableToAdd + + Map> nodeInputTo = new HashMap<>(); // For op x -> y, x is key, y is value. Note that these are OP names not VARIABLE names + + int nNodes = tfGraph.getNodeCount(); + + //First, add any constants, placeholders, and zero-input ops + SameDiff sd = SameDiff.create(); + for (int i = 0; i < nNodes; i++) { + NodeDef nd = tfGraph.getNode(i); + String op = nd.getOp(); + String name = nd.getName(); + + int nInputs = nd.getInputCount(); + + if ("Const".equals(op) || "Placeholder".equals(op) || nInputs == 0) { + availableToAdd.add(nd); + availableToAddSet.add(name); + } else { + remainingNodes.put(name, nd); + for (int in = 0; in < nInputs; in++) { + String inOpName = stripControl(nd.getInput(in)); + inOpName = stripVarSuffix(inOpName); + + if (!nodeInputTo.containsKey(inOpName)) { + nodeInputTo.put(inOpName, new HashSet()); + } + nodeInputTo.get(inOpName).add(name); + } + } + } + + Map mergeOpsPostProcess = new HashMap<>(); + + //Go through ops in order, and add to the graph + Map> constControlDeps = new HashMap<>(); //Key: constant name. Value: control dependencies + while (!availableToAdd.isEmpty()) { + NodeDef nd = availableToAdd.remove(); + String name = nd.getName(); + String opName = nd.getOp(); + int nIn = nd.getInputCount(); + + availableToAddSet.remove(name); + + log.trace("Adding operation to graph: {} (name={})", opName, name); + + boolean skipCase = false; + if(opFilter != null && opFilter.skipOp(nd, sd, nd.getAttrMap(), tfGraph)){ + log.debug("Skipping op {} of type {} due to op filter", name, opName); + //Don't continue at this point - we still need to process what this feeds into... + skipCase = true; + } else { + if (importOverride == null || !importOverride.containsKey(name)) { + //Standard case + if ("Const".equals(opName)) { + //Get array, create a constant + TensorProto tfTensor = nd.getAttrOrThrow("value").getTensor(); + TFTensorMapper m = TFTensorMappers.newMapper(tfTensor); + INDArray arr = m.toNDArray(); + sd.constant(name, arr); + int inputCount = nd.getInputCount(); + if (inputCount > 0) { + //Very likely control dependency. i.e., "we must execute op X before the constant is really available to be used" + List l = new ArrayList<>(inputCount); + for (int i = 0; i < inputCount; i++) { + String n = nd.getInput(i); + if (!isControlDep(n)) { + throw new IllegalStateException("Found non-control dependency input \"" + n + "\" for constant \"" + name + "\""); + } + String n2 = stripControl(n); + l.add(n2); + } + constControlDeps.put(name, l); + } + } else if ("Placeholder".equals(opName) || "PlaceholderWithDefault".equals(opName)) { + //TODO support the "WithDefault" array + + Map attrMap = nd.getAttrMap(); + boolean shapeAvailable = attrMap.containsKey("shape"); + long[] shape; + if (shapeAvailable) { + TensorShapeProto shapeProto = attrMap.get("shape").getShape(); + shape = shapeFromShapeProto(shapeProto); + } else { + //Some placeholders don't have any shape restrictions - i.e., accept anything... + shape = null; + } + + + org.tensorflow.framework.DataType tfDtype = attrMap.get("dtype").getType(); + org.nd4j.linalg.api.buffer.DataType dt = convertType(tfDtype); + sd.placeHolder(name, dt, shape); + } else { + /* + Normal ops. Process in the following order: + 1. Create the op instance + 2. Add op to graph + 3. Import from TF (to set attributes) + 4. Calculate output dtypes + 5. Create and add output variables to graph + + Note: one constraint on this order is that some ops import modify the graph structure. + Notable example: concat op - it removes the axis op and converts the value to an iArg + https://github.com/eclipse/deeplearning4j/issues/8285 + */ + DifferentialFunction dfInstance = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName); + Preconditions.checkState(dfInstance != null, "Could not find class for TF Ops: {}", opName); + + DifferentialFunction df; + try { + df = dfInstance.getClass().newInstance(); + } catch (Throwable t) { + //Should never happen because function was already created via no-arg constructor earlier + throw new RuntimeException(t); + } + df.setSameDiff(sd); + df.setOwnName(name); + + //Process inputs + List inNames = new ArrayList<>(nIn); + List controlDeps = null; + for (int i = 0; i < nIn; i++) { + String origInName = nd.getInput(i); + String inName = stripControl(origInName); + boolean isControlDep = isControlDep(origInName); + if (isControlDep) { + if (controlDeps == null) + controlDeps = new ArrayList<>(); + controlDeps.add(inName); + } + + if (!isControlDep) { + inNames.add(inName); + } + + //Update Variable.inputsForOp for all variables that feed into this op + // Such variables must have already been created, given we process in order + Variable v = sd.getVariables().get(inName); + + if (v == null && df instanceof Merge) { + //Edge case for import - we allow merge ops to be added before both inputs are available + //This is to break the cycles in loops, otherwise we can't process anything in order + mergeOpsPostProcess.put(df.getOwnName(), inName); + continue; + } + + if (!isControlDep && (v.getInputsForOp() == null || !v.getInputsForOp().contains(name))) { + //May already be present - for example, add(x,x) + if (v.getInputsForOp() == null) + v.setInputsForOp(new ArrayList()); + v.getInputsForOp().add(name); + } else if (isControlDep) { + if (v.getControlDepsForOp() == null) + v.setControlDepsForOp(new ArrayList()); + if (!v.getControlDepsForOp().contains(name)) { + v.getControlDepsForOp().add(name); + } + } + } + + //Create SameDiffOp instance and add to graph + SameDiffOp op = SameDiffOp.builder() + .name(name) + .op(df) + .inputsToOp(inNames) + //.outputsOfOp(outNames) //We'll set this later + .controlDeps(controlDeps) + .build(); + sd.getOps().put(name, op); + + + Map attrMap = nd.getAttrMap(); + df.initFromTensorFlow(nd, sd, attrMap, tfGraph); //TODO REMOVE TFGRAPH ENTIRELY FROM THIS CALL - it encourages hacky and really brittle stuff like input array to attribute conversion + + //DType calculate for output variables (set/correct if necessary) + List newInNames = sd.getOps().get(name).getInputsToOp(); //Just in case import has modified this, like for concat case + List newInDtypes = new ArrayList<>(newInNames.size()); + if (df instanceof Merge) { + //Merge op: as noted elsewhere, we allow merge to be processed when only one of the inputs is available + // to break cycles for loops + //We know that Merge op has the restriction of the same datatype for both inputs, so we'll + SDVariable v1 = sd.getVariable(newInNames.get(0)); + SDVariable v2 = sd.getVariable(newInNames.get(1)); + org.nd4j.linalg.api.buffer.DataType dt1 = (v1 == null ? v2.dataType() : v1.dataType()); + org.nd4j.linalg.api.buffer.DataType dt2 = (v2 == null ? v1.dataType() : v2.dataType()); + newInDtypes.add(dt1); + newInDtypes.add(dt2); + } else { + for (String s : newInNames) { + SDVariable v = sd.getVariable(s); + newInDtypes.add(v.dataType()); + } + } + + List outDTypes = df.calculateOutputDataTypes(newInDtypes); + SDVariable[] outSDVars = new SDVariable[outDTypes.size()]; + Variable[] outVars = new Variable[outDTypes.size()]; + List outNames = new ArrayList<>(outDTypes.size()); + + //Create output variables and add to graph + for (int i = 0; i < outDTypes.size(); i++) { + org.nd4j.linalg.api.buffer.DataType dt = outDTypes.get(i); + String varName = name + (i == 0 ? "" : ":" + i); + outSDVars[i] = sd.var(varName, VariableType.ARRAY, null, dt, (long[]) null); + outNames.add(varName); + + outVars[i] = Variable.builder() + .name(varName) + .variable(outSDVars[i]) + .inputsForOp(null) //This is updated incrementally as other ops are added + .controlDepsForOp(null) //Control deps are handled later + .controlDepsForVar(null) + .outputOfOp(name) + .build(); + + sd.getVariables().put(varName, outVars[i]); + log.trace("Added variable to graph: {} (output of op {})", varName, name); + } + sd.getOps().get(name).setOutputsOfOp(outNames); + + log.trace("Imported op: {} (name={})", opName, name); + } + } else { + //Import override case + TFImportOverride o = importOverride.get(name); + + log.debug("Importing op {} using override {}", opName, importOverride); + + //First, get inputs: + List inputs = new ArrayList<>(nIn); + List controlDeps = null; + for (int i = 0; i < nIn; i++) { + String inName = nd.getInput(i); + boolean controlDep = isControlDep(inName); + + SDVariable v = sd.getVariable(name); + + if (controlDep) { + if (controlDeps == null) + controlDeps = new ArrayList<>(); + controlDeps.add(v); + } else { + inputs.add(v); + } + + o.initFromTensorFlow(inputs, controlDeps, nd, sd, nd.getAttrMap(), tfGraph); + } + } } - return ret; - } - else if(node.getOp().equals("Switch")) { - boolean ret = false; - for(int i = 0; i < node.getInputCount(); i++) { - //while loop - ret = ret || !node.getInput(i).endsWith("/Merge") || !node.getInput(i).endsWith("/LoopCond"); + //Now that we have just added an op (or variable) - check what this feeds into, and see what we can now process + // as a result + if (nodeInputTo.containsKey(name)) { + Set set = nodeInputTo.get(name); + for (String nextOp : set) { + NodeDef nextOpDef = remainingNodes.get(nextOp); + if (nextOpDef == null) { + if (sd.getOps().containsKey(nextOp)) { + //Already processed this. + //Almost certainly the close of a loop - like NextIteration -> Merge case + continue; + } + //Should never happen + throw new IllegalStateException("Could not find op definition for op to import: " + nextOp); + } + int nInNext = nextOpDef.getInputCount(); + boolean allAlreadyInGraph = true; + int nonControlSeenCount = 0; + for (int i = 0; i < nInNext; i++) { + String s = nextOpDef.getInput(i); + String inName = stripControl(nextOpDef.getInput(i)); + +// log.info("Input: {}, {}", s, inName); + + if (!sd.hasVariable(inName) && !skipCase) { +// log.info("Not found: {} for op {}", inName, nextOpDef.getName()); + allAlreadyInGraph = false; + break; + } else if (!isControlDep(s)) { + nonControlSeenCount++; + } + } + + //Merge ops are an edge case. We'll allow these to be executed with just ONE input, to break + // the cycle in loops. In loops, generally we have (Enter, NextIteration) -> Merge, which + // of course can't be done if we strictly require all inputs to be available + boolean mergeCase = (nonControlSeenCount > 0 && "Merge".equals(nextOpDef.getOp())); + + if (allAlreadyInGraph || mergeCase) { + //Can process this op, add it to the queue for processing + if (!availableToAddSet.contains(nextOp)) { + //Avoid processing same op multiple times, for repeated inputs to one op, etc + availableToAdd.add(nextOpDef); + availableToAddSet.add(nextOp); + log.trace("Added to processing queue: {} (name={})", nextOpDef.getOp(), nextOp); + } + } + } } + //Finally, remove the just processed op from remainingNodes map: + remainingNodes.remove(name); + } + + //Post process the control dependencies, if any (done after because dependencies may not exist when imported) + for (Map.Entry> e : constControlDeps.entrySet()) { + String varName = e.getKey(); + List cdOpNames = e.getValue(); + sd.getVariables().get(varName).setControlDeps(cdOpNames); + + for (String s : cdOpNames) { + SameDiffOp sdo = sd.getOps().get(s); + if (sdo.getControlDepFor() == null) + sdo.setControlDepFor(new ArrayList()); + List l = sdo.getControlDepFor(); + if (!l.contains(s)) + l.add(varName); + } + } + + //Post process the merge ops - all we are missing is a Variable.getInputsForOp().add(mergeOpName); + for (Map.Entry e : mergeOpsPostProcess.entrySet()) { + Variable v = sd.getVariables().get(e.getValue()); + if (v.getInputsForOp() == null) + v.setInputsForOp(new ArrayList()); + v.getInputsForOp().add(e.getKey()); + } + + Preconditions.checkState(remainingNodes.isEmpty(), "%s Unprocessed nodes: %s", remainingNodes.size(), remainingNodes.keySet()); + + return sd; + } + + + /** + * Get the shape from a TensorShapeProto + * + * @param tensorShapeProto Shape + * @return Shape as long[] + */ + private static long[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) { + long[] shape = new long[tensorShapeProto.getDimList().size()]; + for (int i = 0; i < shape.length; i++) { + shape[i] = tensorShapeProto.getDim(i).getSize(); + } + + return shape; + } + + /** + * Convert from TF proto datatype to ND4J datatype + * + * @param tfType TF datatype + * @return ND4J datatype + */ + public static org.nd4j.linalg.api.buffer.DataType convertType(org.tensorflow.framework.DataType tfType) { + switch (tfType) { + case DT_DOUBLE: + return org.nd4j.linalg.api.buffer.DataType.DOUBLE; + case DT_FLOAT: + return org.nd4j.linalg.api.buffer.DataType.FLOAT; + case DT_HALF: + return org.nd4j.linalg.api.buffer.DataType.HALF; + case DT_BFLOAT16: + return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; + case DT_INT8: + return org.nd4j.linalg.api.buffer.DataType.BYTE; + case DT_INT16: + return org.nd4j.linalg.api.buffer.DataType.SHORT; + case DT_INT32: + return org.nd4j.linalg.api.buffer.DataType.INT; + case DT_INT64: + return org.nd4j.linalg.api.buffer.DataType.LONG; + case DT_UINT8: + return org.nd4j.linalg.api.buffer.DataType.UBYTE; + case DT_STRING: + return org.nd4j.linalg.api.buffer.DataType.UTF8; + case DT_BOOL: + return org.nd4j.linalg.api.buffer.DataType.BOOL; + + default: + return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; + } + } + + /** + * @return True if the specified name represents a control dependency (starts with "^") + */ + protected static boolean isControlDep(String name) { + return name.startsWith("^"); + } + + /** + * @return The specified name without the leading "^" character (if any) that appears for control dependencies + */ + protected static String stripControl(String name) { + if (name.startsWith("^")) { + return name.substring(1); + } + return name; + } + + /** + * Remove the ":1" etc suffix for a variable name to get the op name + * + * @param varName Variable name + * @return Variable name without any number suffix + */ + protected static String stripVarSuffix(String varName) { + if (varName.matches(".*:\\d+")) { + int idx = varName.lastIndexOf(':'); + String ret = varName.substring(0, idx); return ret; } -*/ - return true; + return varName; } - @Override - public String getTargetMappingForOp(DifferentialFunction function, NodeDef node) { - return function.opName(); + /** + * Convert the tensor to an NDArray (if possible and if array is available) + * + * @param node Node to get NDArray from + * @return NDArray + */ + public static INDArray getNDArrayFromTensor(NodeDef node) { + //placeholder of some kind + if (!node.getAttrMap().containsKey("value")) { + return null; + } + + val tfTensor = node.getAttrOrThrow("value").getTensor(); + INDArray out = mapTensorProto(tfTensor); + return out; } - @Override - public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) { - for(int i = 0; i < graph.getNodeCount(); i++) { + /** + * Convert a TensorProto to an INDArray + * + * @param tfTensor Tensor proto + * @return INDArray + */ + public static INDArray mapTensorProto(TensorProto tfTensor) { + TFTensorMapper m = TFTensorMappers.newMapper(tfTensor); + if (m == null) { + throw new RuntimeException("Not implemented datatype: " + tfTensor.getDtype()); + } + INDArray out = m.toNDArray(); + return out; + } + + @Deprecated //To be removed + public static NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) { + for (int i = 0; i < graph.getNodeCount(); i++) { val node = graph.getNode(i); - if(node.getName().equals(name)) + if (node.getName().equals(name)) return node; } return null; } - @Override - public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map> propertyMappingsForFunction) { - if(node == null) { - throw new ND4JIllegalStateException("No node found for name " + name); - } - - - val mapping = propertyMappingsForFunction.get(getOpType(node)).get(name); - val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); - - - if(mapping.getTfInputPosition() != null && mapping.getTfInputPosition() < node.getInputCount()) { - int tfMappingIdx = mapping.getTfInputPosition(); - if(tfMappingIdx < 0) - tfMappingIdx += node.getInputCount(); - - val input = node.getInput(tfMappingIdx); - val inputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,input); - INDArray arr = getArrayFrom(inputNode,graph); - if(arr == null && sameDiff.hasVariable(input)) { - arr = sameDiff.getArrForVarName(input); - } - - if(arr == null && inputNode != null) { - sameDiff.addPropertyToResolve(on,name); - sameDiff.addVariableMappingForField(on,name,getNodeName(inputNode.getName())); - return; - } else if(inputNode == null) { - //TODO need to do anything here given new design? - //sameDiff.addAsPlaceHolder(input); - return; - } - - val field = fields.get(name); - val type = field.getType(); - if(type.equals(int[].class)) { - on.setValueFor(field,arr.data().asInt()); - } - else if(type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) { - if(mapping.getShapePosition() != null) { - on.setValueFor(field,arr.size(mapping.getShapePosition())); - } - else - on.setValueFor(field,arr.getInt(0)); - - } - else if(type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) { - on.setValueFor(field,arr.getDouble(0)); - } - - - } - else { - val tfMappingAttrName = mapping.getTfAttrName(); - if(tfMappingAttrName == null) { - return; - } - - if(!node.containsAttr(tfMappingAttrName)) { - return; - } - - - val attr = node.getAttrOrThrow(tfMappingAttrName); - val type = attr.getType(); - if(fields == null) { - throw new ND4JIllegalStateException("No fields found for op [" + mapping + "]"); - } - - if(mapping.getPropertyNames() == null) { - throw new ND4JIllegalStateException("no property found for [" + name + "] in op [" + on.opName()+"]"); - } - - val field = fields.get(mapping.getPropertyNames()[0]); - - Object valueToSet = null; - switch(type) { - case DT_BOOL: - valueToSet = attr.getB(); - break; - case DT_INT8: - valueToSet = attr.getI(); - break; - case DT_INT16: - valueToSet = attr.getI(); - break; - case DT_INT32: - valueToSet = attr.getI(); - break; - case DT_FLOAT: - valueToSet = attr.getF(); - break; - case DT_DOUBLE: - valueToSet = attr.getF(); - break; - case DT_STRING: - valueToSet = attr.getS(); - break; - case DT_INT64: - valueToSet = attr.getI(); - break; - - - } - - if(field != null && valueToSet != null) - on.setValueFor(field,valueToSet); - } - } - - - /** - * {@inheritDoc} - */ - @Override - public boolean isPlaceHolderNode(NodeDef node) { - return node.getOp().startsWith("Placeholder"); - } - - - /** - * {@inheritDoc} - */ - @Override - public void dumpBinaryProtoAsText(File inputFile, File outputFile) { - try { - GraphDef graphDef = GraphDef.parseFrom(new BufferedInputStream(new FileInputStream(inputFile))); - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); - for(NodeDef node : graphDef.getNodeList()) { - bufferedWriter.write(node.toString()); - } - - bufferedWriter.flush(); - bufferedWriter.close(); - - } catch (IOException e) { - e.printStackTrace(); - } - } - - @Override - public long[] getShapeFromAttr(AttrValue attr) { - return shapeFromShapeProto(attr.getShape()); - } - - @Override - public Map getAttrMap(NodeDef nodeDef) { - return nodeDef.getAttrMap(); - } - - @Override - public String getName(NodeDef nodeDef) { - return nodeDef.getName(); - } - - @Override - public boolean alreadySeen(NodeDef nodeDef) { - return seenNodes.contains(nodeDef.getName()); - } - - @Override - public boolean isVariableNode(NodeDef nodeDef) { - boolean isVar = nodeDef.getOp().startsWith("VariableV") || nodeDef.getOp().equalsIgnoreCase("const"); - return isVar; - } - - @Override - public boolean shouldSkip(NodeDef opType) { - if(opType == null) - return true; - - boolean endsWithRead = opType.getName().endsWith("/read"); - return endsWithRead; - } - - @Override - public boolean hasShape(NodeDef nodeDef) { - return nodeDef.containsAttr(SHAPE_KEY); - } - - @Override - public long[] getShape(NodeDef nodeDef) { - return getShapeFromAttr(nodeDef.getAttrOrThrow(SHAPE_KEY)); - } - - @Override - public INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) { - if(nodeDef == null) { + @Deprecated //To be removed + public static INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) { + if (nodeDef == null) { return null; } - return getNDArrayFromTensor(nodeDef.getName(),nodeDef, graph); - } - - @Override - public String getOpType(NodeDef nodeDef) { - return nodeDef.getOp(); - } - - /** - * - * @param graphDef - * @return - */ - @Override - public List getNodeList(GraphDef graphDef) { - return graphDef.getNodeList(); - } - - /** - * - * @param name the tensorflow or onnx name - * @return - */ - @Override - public DifferentialFunction getMappedOp(String name) { - return DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(name); - } - - - /** - * Map a tensorflow node name - * to the samediff equivalent - * for import - * @param name the name to change - * @return the input tensorflow name - */ - public String getNodeName(String name) { - //tensorflow adds colons to the end of variables representing input index, this strips those off - String ret = name; - if(ret.startsWith("^")) - ret = ret.substring(1); - if(ret.endsWith("/read")) { - ret = ret.replace("/read",""); - } - if(ret.endsWith(":0")){ - ret = ret.substring(0, ret.length()-2); - } - return ret; - } - - public boolean isControlDependency(String name){ - return name.startsWith("^"); - } - - - - @Override - public Map variablesForGraph(GraphDef graphDef) { - Map ret = new LinkedHashMap<>(); - List nodeList = graphDef.getNodeList(); - for(NodeDef nodeDef : nodeList) { - if(nodeDef.getName().endsWith("/read")) { - continue; - } - - - val name = translateToSameDiffName(nodeDef.getName(), nodeDef); - ret.put(name,nodeDef); - } - - return ret; - } - - @Override - public String translateToSameDiffName(String name, NodeDef node) { - if(isVariableNode(node) || isPlaceHolder(node)) { - return name; - } - - StringBuilder stringBuilder = new StringBuilder(); - //strip arg number - if(name.contains(":")) { - name = name.substring(0,name.lastIndexOf(':')); - stringBuilder.append(name); - } - else { - stringBuilder.append(name); - } - - - return stringBuilder.toString(); - } - - //Strip the variable suffix to give the node name: "Unique:1" -> "Unique" - public String varNameToOpName(String varName){ - int idx = varName.lastIndexOf(':'); - if(idx < 0) - return varName; - return varName.substring(0, idx); - } - - public static int varNameToOpOutputNumber(String varName){ - int idx = varName.lastIndexOf(':'); - if(idx < 0) - return 0; - String n = varName.substring(idx+1); - return Integer.parseInt(n); - } - - - @Override - public Message.Builder getNewGraphBuilder() { - return GraphDef.newBuilder(); - } - - @Override - public GraphDef parseGraphFrom(byte[] inputStream) throws IOException { - return GraphDef.parseFrom(inputStream); - } - - @Override - public GraphDef parseGraphFrom(InputStream inputStream) throws IOException { - return GraphDef.parseFrom(inputStream); - } - - protected void importCondition(String conditionName, NodeDef tfNode, ImportState importState) { - /** - * Cond structure: - * - */ - } - - @Override - public void mapNodeType(NodeDef tfNode, ImportState importState, - OpImportOverride importOverride, - OpImportFilter opFilter) { - if (shouldSkip(tfNode) || alreadySeen(tfNode) || isVariableNode(tfNode)) { - return; - } - - - SameDiff diff = importState.getSameDiff(); - if (isVariableNode(tfNode)) { - List dimensions = new ArrayList<>(); - Map attributes = getAttrMap(tfNode); - if (attributes.containsKey(VALUE_ATTR_KEY)) { - diff.var(getName(tfNode),getArrayFrom(tfNode,importState.getGraph())); - } - else if (attributes.containsKey(SHAPE_KEY)) { - AttrValue shape = attributes.get(SHAPE_KEY); - long[] shapeArr = getShapeFromAttr(shape); - int dims = shapeArr.length; - if (dims > 0) { - // even vector is 2d in nd4j - if (dims == 1) - dimensions.add(1L); - - for (int e = 0; e < dims; e++) { - // TODO: eventually we want long shapes :( - dimensions.add(getShapeFromAttr(shape)[e]); - } - } - } - } - - else if(isPlaceHolder(tfNode)) { - SDVariable var = diff.getVariable(getName(tfNode)); - Preconditions.checkState(var.isPlaceHolder(), "Variable should be marked as placeholder at this point: %s", var); - } else { - val opName = tfNode.getOp(); - - if(importOverride != null){ - //First, get inputs: - int numInputs = tfNode.getInputCount(); - List inputs = new ArrayList<>(numInputs); - List controlDeps = null; - for( int i=0; i this) - if (v == null) { - //Check 'op skip' edge case - boolean shouldSkip = false; - if(opFilter != null){ - //Get the input node - List l = importState.getGraph().getNodeList(); - NodeDef inputNodeDef = null; - for(NodeDef nd : l){ - if(inName.equals(nd.getName())){ - inputNodeDef = nd; - break; - } - } - Preconditions.checkState(inputNodeDef != null, "Could not find node with name \"%s\"", inName); - shouldSkip = true; - } - - if(!shouldSkip) { - //First: try to work out the datatype of this input node - //Given we haven't already imported it at this point, it must be the 2nd or later output of an op - - String inputOpName = varNameToOpName(inName); - NodeDef inputOp = importState.getVariables().get(inputOpName); - int outputIdx = varNameToOpOutputNumber(name); - org.nd4j.linalg.api.buffer.DataType dt = dataTypeForTensor(inputOp, outputIdx); - if (dt == org.nd4j.linalg.api.buffer.DataType.UNKNOWN) - dt = null; //Infer it later - - - v = diff.var(name, VariableType.ARRAY, null, dt, (long[]) null); - } - } - - if(controlDep){ - if(controlDeps == null) - controlDeps = new ArrayList<>(); - controlDeps.add(v); - } else { - inputs.add(v); - } - } - - log.info("Importing op {} using override {}", opName, importOverride); - importOverride.initFromTensorFlow(inputs, controlDeps, tfNode, diff, getAttrMap(tfNode), importState.getGraph()); - } else { - - val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName); - if (differentialFunction == null) { - throw new ND4JIllegalStateException("No tensorflow op found for " + opName + " possibly missing operation class?"); - } - try { - DifferentialFunction newInstance = differentialFunction.getClass().newInstance(); - List args = new ArrayList<>(); - List controlDeps = null; - newInstance.setOwnName(tfNode.getName()); - - int x = 0; - for (int i = 0; i < tfNode.getInputCount(); i++) { - String inName = tfNode.getInput(i); - String inputOpName = varNameToOpName(inName); - NodeDef inputNode = importState.getVariables().get(inputOpName); - - if (shouldSkip(inputNode) && !inName.endsWith("/read")) - continue; - - boolean controlDep = isControlDependency(inName); - String name = getNodeName(inName); - - SDVariable v = diff.getVariable(name); - - //At this point, all placeholders, variables and constants should have been imported - //This: this should be an array type variable (i.e., activations) - if (v == null) { - //First: try to work out the datatype of this input node - //Given we haven't already imported it at this point, it must be the 2nd or later output of an op - - NodeDef inputOp = importState.getVariables().get(inputOpName); - int outputIdx = varNameToOpOutputNumber(name); - org.nd4j.linalg.api.buffer.DataType dt = dataTypeForTensor(inputOp, outputIdx); - if (dt == org.nd4j.linalg.api.buffer.DataType.UNKNOWN) - dt = null; //Infer it later - - - v = diff.var(name, VariableType.ARRAY, null, dt, (long[]) null); - } - - if (controlDep) { - //Is only a control dependency input to op, not a real data input - if (controlDeps == null) - controlDeps = new ArrayList<>(); - if (!controlDeps.contains(name)) - controlDeps.add(name); - } else { - //Is a standard/"real" op input - args.add(v); - } - } - - - diff.addArgsFor(args.toArray(new SDVariable[args.size()]), newInstance); - newInstance.setSameDiff(importState.getSameDiff()); - - if (controlDeps != null) { - SameDiffOp op = diff.getOps().get(newInstance.getOwnName()); - op.setControlDeps(controlDeps); - - //Also record this on the variables: - for (String s : controlDeps) { - Variable v = diff.getVariables().get(s); - if (v.getControlDepsForOp() == null) - v.setControlDeps(new ArrayList()); - List l = v.getControlDepsForOp(); - if (!l.contains(op.getName())) - l.add(op.getName()); - } - } - - newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph()); - mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction()); - importState.getSameDiff().putOpForId(newInstance.getOwnName(), newInstance); - //ensure we can track node name to function instance later. - diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance); - } catch (Exception e) { - log.error("Failed to import op [{}]", opName); - throw new RuntimeException(e); - } - } - } - } - - - /** - * Calls {@link #initFunctionFromProperties(DifferentialFunction, Map, NodeDef, GraphDef)} - * using {@link DifferentialFunction#tensorflowName()} - * @param on the function to use init on - * @param attributesForNode the attributes for the node - * @param node - * @param graph - */ - public void initFunctionFromProperties(DifferentialFunction on, Map attributesForNode, NodeDef node, GraphDef graph) { - initFunctionFromProperties(on.tensorflowName(),on,attributesForNode,node,graph); + return getNDArrayFromTensor(nodeDef); } /** * Init a function's attributes - * @param mappedTfName the tensorflow name to pick (sometimes ops have multiple names - * @param on the function to map + * + * @param mappedTfName the tensorflow name to pick (sometimes ops have multiple names + * @param on the function to map * @param attributesForNode the attributes for the node * @param node * @param graph + * @deprecated To be removed */ - public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map attributesForNode, NodeDef node, GraphDef graph) { + @Deprecated + public static void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map attributesForNode, NodeDef node, GraphDef graph) { val properties = on.mappingsForFunction(); val tfProperties = properties.get(mappedTfName); val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); @@ -699,8 +657,8 @@ public class TFGraphMapper extends BaseGraphMapper need to map data format before mapping strides //Solution: map nodes without adapters before nodes with adapters. This doesn't guarantee we'll always be // mapping in the right order (for example, we might have adapter(x) depends on adapter(y)) but it should catch most cases - Map map; - if(attributeAdapters == null || !attributeAdapters.containsKey(mappedTfName)) { + Map map; + if (attributeAdapters == null || !attributeAdapters.containsKey(mappedTfName)) { map = tfProperties; } else { map = new LinkedHashMap<>(); @@ -718,24 +676,24 @@ public class TFGraphMapper extends BaseGraphMapper entry : map.entrySet()){ + for (Map.Entry entry : map.entrySet()) { val tfAttrName = entry.getValue().getTfAttrName(); val currentField = fields.get(entry.getKey()); AttributeAdapter adapter = null; - if(attributeAdapters != null && !attributeAdapters.isEmpty()) { + if (attributeAdapters != null && !attributeAdapters.isEmpty()) { val mappers = attributeAdapters.get(mappedTfName); val adapterFor = mappers.get(entry.getKey()); adapter = adapterFor; } - if(tfAttrName != null) { - if(currentField == null) { + if (tfAttrName != null) { + if (currentField == null) { continue; } - if(attributesForNode.containsKey(tfAttrName)) { + if (attributesForNode.containsKey(tfAttrName)) { val attr = attributesForNode.get(tfAttrName); switch (attr.getValueCase()) { case B: @@ -743,77 +701,69 @@ public class TFGraphMapper extends BaseGraphMapper 0){ - for(int i=0; i 0){ //Looks like a few OpDef instances have outputs but don't actually list them... example: NoOp - Preconditions.checkState(outNum < actualOutputCount, "Cannot get output argument %s from op %s with %s output variables - variable %s", outNum, actualOutputCount, tensorProto.getName(), tensorProto.getName()); - - int argIdx = outNum; - if(outputArgCount != actualOutputCount){ - //Map backwards accunting for fact that each output arg might correspond to multiple variables: for output variable x, which argument is this? - int idx = 0; - int soFar = 0; - while(soFar + outVarsPerOutputArg[idx] <= outNum){ - soFar += outVarsPerOutputArg[idx++]; - } - argIdx = idx; - } - - OpDef.ArgDef argDef = opDef.getOutputArg(argIdx); - String typeAttr = argDef.getTypeAttr(); - if(typeAttr != null && tensorProto.containsAttr(typeAttr)){ - tfType = tensorProto.getAttrOrThrow(typeAttr).getType(); - } else { - return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; - } - - } else { - if(tensorProto.getOp().equals("NoOp")){ - return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; - } else if(tensorProto.getOp().equals("Assert")){ - return org.nd4j.linalg.api.buffer.DataType.BOOL; - } - //Not in ops.proto - log.debug("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", tensorProto.getName(), tensorProto.getOp()); - - //No descriptor... try to fall back on common type attribute names - if(!tensorProto.containsAttr("dtype") && !tensorProto.containsAttr("Tidx") && !tensorProto.containsAttr("T")) - return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; - - tfType = tensorProto.containsAttr("dtype") ? tensorProto.getAttrOrThrow("dtype").getType() - : tensorProto.containsAttr("T") ? tensorProto.getAttrOrThrow("T").getType() : tensorProto - .getAttrOrThrow("Tidx").getType(); - } - - return convertType(tfType); - } - - public static org.nd4j.linalg.api.buffer.DataType convertType(org.tensorflow.framework.DataType tfType){ - switch(tfType) { - case DT_DOUBLE: return org.nd4j.linalg.api.buffer.DataType.DOUBLE; - case DT_FLOAT: return org.nd4j.linalg.api.buffer.DataType.FLOAT; - case DT_HALF: return org.nd4j.linalg.api.buffer.DataType.HALF; - case DT_BFLOAT16: return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; - case DT_INT8: return org.nd4j.linalg.api.buffer.DataType.BYTE; - case DT_INT16: return org.nd4j.linalg.api.buffer.DataType.SHORT; - case DT_INT32: return org.nd4j.linalg.api.buffer.DataType.INT; - case DT_INT64: return org.nd4j.linalg.api.buffer.DataType.LONG; - case DT_UINT8: return org.nd4j.linalg.api.buffer.DataType.UBYTE; - case DT_STRING: return org.nd4j.linalg.api.buffer.DataType.UTF8; - case DT_BOOL: return org.nd4j.linalg.api.buffer.DataType.BOOL; - - default: return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; - } - } - - @Override - public boolean isStringType(NodeDef tensorProto){ - DataType dt = null; - if(tensorProto.containsAttr("dtype")){ - dt = tensorProto.getAttrOrThrow("dtype").getType(); - } else if(tensorProto.containsAttr("T")){ - dt = tensorProto.getAttrOrThrow("T").getType(); - } else if(tensorProto.containsAttr("Tidx")){ - dt = tensorProto.getAttrOrThrow("Tidx").getType(); - } - - return dt == DataType.DT_STRING || dt == DataType.DT_STRING_REF; - } - - - @Override - public String getAttrValueFromNode(NodeDef nodeDef, String key) { - return nodeDef.getAttrOrThrow(key).getS().toStringUtf8(); - } - - @Override - public long[] getShapeFromAttribute(AttrValue attrValue) { - TensorShapeProto shape = attrValue.getShape(); - long[] ret = new long[shape.getDimCount()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = (int) shape.getDim(i).getSize(); + if (ret.endsWith(":0")) { + ret = ret.substring(0, ret.length() - 2); } return ret; } - @Override - public boolean isPlaceHolder(NodeDef nodeDef) { - return nodeDef.getOp().startsWith("Placeholder"); + /** + * Determine if the node represents a variable node (based on op name) + * + * @param nodeDef Node to check if a variable + * @return True if a variable node + */ + public static boolean isVariableNode(NodeDef nodeDef) { + boolean isVar = nodeDef.getOp().startsWith("VariableV") || nodeDef.getOp().equalsIgnoreCase("const"); + return isVar; } - @Override - public boolean isConstant(NodeDef nodeDef) { - return nodeDef.getOp().startsWith("Const"); - } - - @Override - public List getControlDependencies(NodeDef node){ - int numInputs = node.getInputCount(); - if(numInputs == 0) - return null; - - List out = null; - for( int i=0; i(); - out.add(getNodeName(in)); //Remove "^" prefix - } - } - return out; - } - - @Override - public INDArray getNDArrayFromTensor(String tensorName, NodeDef node, GraphDef graph) { - //placeholder of some kind - if(!node.getAttrMap().containsKey("value")) { - return null; - } - - val tfTensor = node.getAttrOrThrow("value").getTensor(); - INDArray out = mapTensorProto(tfTensor); - return out; - } - - - - public INDArray mapTensorProto(TensorProto tfTensor) { - - TFTensorMapper m = TFTensorMappers.newMapper(tfTensor); - if(m == null){ - throw new RuntimeException("Not implemented datatype: " + tfTensor.getDtype()); - } - INDArray out = m.toNDArray(); - return out; - } - - protected static void setFloat16ValueFromInt(INDArray arr, int idx, int bytesAsPaddedInt){ - ByteBuffer bb = arr.data().pointer().asByteBuffer(); - bb.put(2*idx, (byte)((bytesAsPaddedInt >> 8) & 0xff)); - bb.put(2*idx+1, (byte)(bytesAsPaddedInt & 0xff)); - } - - @Override - public long[] getShapeFromTensor(NodeDef tensorProto) { - if(tensorProto.containsAttr("shape")) { - return shapeFromShapeProto(tensorProto.getAttrOrThrow("shape").getShape()); - - } - //yet to be determined shape, or tied to an op where output shape is dynamic - else if(!tensorProto.containsAttr("value")) { - return null; - - } - else - return shapeFromShapeProto(tensorProto.getAttrOrThrow("value").getTensor().getTensorShape()); - } - - @Override - public Set opsToIgnore() { - return graphMapper; - } - - - @Override - public String getInputFromNode(NodeDef node, int index) { - return node.getInput(index); - } - - @Override - public int numInputsFor(NodeDef nodeDef) { - return nodeDef.getInputCount(); - } - - private long[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) { - long[] shape = new long[tensorShapeProto.getDimList().size()]; - for(int i = 0; i < shape.length; i++) { - shape[i] = tensorShapeProto.getDim(i).getSize(); - } - - return shape; - } - - /** - * Returns the node for an if statement - * @param from the starting node (a merge node that represents a conditional) - * @param graph the graph to search - * @return an import state representing the nodes for each scope + * Determine if the node is a placeholder + * + * @param nodeDef Node to check + * @return True if the node is a placeholder */ - public IfImportState nodesForIf(NodeDef from, GraphDef graph) { - //Assume we start with a switch statement - int currNodeIndex = graph.getNodeList().indexOf(from); - val trueDefName = from.getInput(1); - val falseDefName = from.getInput(0); - val scopeId = UUID.randomUUID().toString(); - val scopeName = scopeId + "-" + trueDefName.substring(0,trueDefName.indexOf("/")); - val trueDefScopeName = scopeName + "-true-scope"; - val falseDefScopeName = scopeName + "-false-scope"; - - - boolean onFalseDefinition = true; - //start with the true - boolean onTrueDefinition = false; - - List falseBodyNodes = new ArrayList<>(); - List trueBodyNodes = new ArrayList<>(); - List conditionNodes = new ArrayList<>(); - Set seenNames = new LinkedHashSet<>(); - /** - * Accumulate a list backwards to get proper ordering. - * - */ - for(int i = currNodeIndex; i >= 0; i--) { - //switch to false names - if(graph.getNode(i).getName().equals(trueDefName)) { - onFalseDefinition = false; - onTrueDefinition = true; - } - - //on predicate now - if(graph.getNode(i).getName().contains("pred_id")) { - onTrueDefinition = false; - } - //don't readd the same node, this causes a stackoverflow - if(onTrueDefinition && !graph.getNode(i).equals(from)) { - trueBodyNodes.add(graph.getNode(i)); - } - else if(onFalseDefinition && !graph.getNode(i).equals(from)) { - falseBodyNodes.add(graph.getNode(i)); - } - //condition scope now - else { - val currNode = graph.getNode(i); - if(currNode.equals(from)) - continue; - - //break only after bootstrapping the first node (the predicate id node) - if(!seenNames.contains(graph.getNode(i).getName()) && !graph.getNode(i).getName().contains("pred_id")) { - break; - } - - /** - * Continuously add inputs seen for each node in the sub graph that occurs. - * Starting from the predicate id, any node that has inputs in the condition scope - * are by definition within the scope. Any node not encountered after that is considered out of scope. - * This means we break. - */ - for(int inputIdx = 0; inputIdx < currNode.getInputCount(); inputIdx++) { - seenNames.add(currNode.getInput(inputIdx)); - } - - - - //ensure the "current node" is added as well - seenNames.add(graph.getNode(i).getName()); - conditionNodes.add(graph.getNode(i)); - } - } - - /** - * Since we are going over the graph backwards, - * we need to reverse the nodes to ensure proper ordering. - */ - Collections.reverse(falseBodyNodes); - Collections.reverse(trueBodyNodes); - Collections.reverse(conditionNodes); - - - return IfImportState.builder() - .condNodes(conditionNodes) - .falseNodes(falseBodyNodes) - .trueNodes(trueBodyNodes) - .conditionBodyScopeName(falseDefScopeName) - .falseBodyScopeName(falseDefScopeName) - .trueBodyScopeName(trueDefScopeName) - .conditionBodyScopeName(scopeName) - .build(); + public static boolean isPlaceHolder(NodeDef nodeDef) { + return nodeDef.getOp().startsWith("Placeholder"); } - - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java index 8f59a7ef7..39d8e1577 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java @@ -226,22 +226,24 @@ public class TensorFlowImportValidator { } public static TFImportStatus checkModelForImport(String path, InputStream is, boolean exceptionOnRead) throws IOException { - TFGraphMapper m = TFGraphMapper.getInstance(); try { int opCount = 0; Set opNames = new HashSet<>(); try(InputStream bis = new BufferedInputStream(is)) { - GraphDef graphDef = m.parseGraphFrom(bis); - List nodes = m.getNodeList(graphDef); + GraphDef graphDef = GraphDef.parseFrom(bis); + List nodes = new ArrayList<>(graphDef.getNodeCount()); + for( int i=0; i"; if (!isCompressed() && !preventUnpack) return options.format(this); else if (isCompressed() && compressDebug) @@ -5600,4 +5605,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { return false; } + + @Override + public long getId(){ + return arrayId; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index a2860b582..221b4021b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -2814,4 +2814,10 @@ public interface INDArray extends Serializable, AutoCloseable { * @see org.nd4j.linalg.api.ndarray.BaseNDArray#toString(long, boolean, int) */ String toStringFull(); + + /** + * A unique ID for the INDArray object instance. Does not account for views. + * @return INDArray unique ID + */ + long getId(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java index 8c9cdf4e0..5c0577aca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java @@ -24,6 +24,7 @@ import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -200,48 +201,17 @@ public abstract class BaseOp extends DifferentialFunction implements Op { @Override public void setX(INDArray x) { - if (x == null) { - if (args() != null && args().length >= 1) { - SDVariable firstArg = args()[0]; - if (firstArg.getArr() != null) - this.x = firstArg.getArr(); - } else - throw new ND4JIllegalStateException("Unable to set null array for x. Also unable to infer from differential function arguments"); - } else - this.x = x; + this.x = x; } @Override public void setZ(INDArray z) { - if (z == null) { - SDVariable getResult = sameDiff.getVariable(zVertexId); - if (getResult != null) { - if (getResult.getArr() != null) - this.z = getResult.getArr(); - else if(sameDiff.getShapeForVarName(getResult.getVarName()) != null) { - val shape = sameDiff.getShapeForVarName(getResult.getVarName()); - sameDiff.setArrayForVariable(getResult.getVarName(),getResult.getWeightInitScheme().create(getResult.dataType(), shape)); - } - else - throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments"); - - } else - throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments"); - } else - this.z = z; + this.z = z; } @Override public void setY(INDArray y) { - if (y == null) { - if (args() != null && args().length > 1) { - SDVariable firstArg = args()[1]; - if (firstArg.getArr() != null) - this.y = firstArg.getArr(); - } else - throw new ND4JIllegalStateException("Unable to set null array for y. Also unable to infer from differential function arguments"); - } else - this.y = y; + this.y = y; } @Override @@ -265,6 +235,12 @@ public abstract class BaseOp extends DifferentialFunction implements Op { return z; } + @Override + public INDArray getInputArgument(int index){ + Preconditions.checkState(index >= 0 && index < 2, "Input argument index must be 0 or 1, got %s", index); + return index == 0 ? x : y; + } + @Override public SDVariable[] outputVariables(String baseName) { if(zVertexId == null) { @@ -403,4 +379,11 @@ public abstract class BaseOp extends DifferentialFunction implements Op { //Always 1 for legacy/base ops return 1; } + + @Override + public void clearArrays(){ + x = null; + y = null; + z = null; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 10c26d29e..1e06e7f52 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -16,7 +16,6 @@ package org.nd4j.linalg.api.ops; -import org.nd4j.shade.guava.primitives.Ints; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; @@ -24,21 +23,14 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; @@ -71,10 +63,6 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { this.keepDims = keepDims; this.xVertexId = i_v.getVarName(); sameDiff.addArgsFor(new String[]{xVertexId},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } - } else { throw new IllegalArgumentException("Input not null variable."); } @@ -219,14 +207,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - if (!attributesForNode.containsKey("axes")) { - this.dimensions = new int[] { Integer.MAX_VALUE }; - } - else { - val map = OnnxGraphMapper.getInstance().getAttrMap(node); - val dims = Ints.toArray(map.get("axes").getIntsList()); - this.dimensions = dims; - } + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java index 9deb230df..c228c9b8f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java @@ -119,4 +119,9 @@ public interface CustomOp { * otherwise throws an {@link org.nd4j.linalg.exception.ND4JIllegalStateException} */ void assertValidForExecution(); + + /** + * Clear the input and output INDArrays, if any are set + */ + void clearArrays(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index d2190098c..6ebaa5120 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -263,7 +263,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { @Override public INDArray[] outputArguments() { if (!outputArguments.isEmpty()) { - return outputArguments.toArray(new INDArray[outputArguments.size()]); + return outputArguments.toArray(new INDArray[0]); } return new INDArray[0]; } @@ -271,7 +271,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { @Override public INDArray[] inputArguments() { if (!inputArguments.isEmpty()) - return inputArguments.toArray(new INDArray[inputArguments.size()]); + return inputArguments.toArray(new INDArray[0]); return new INDArray[0]; } @@ -389,6 +389,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } public void setInputArgument(int index, INDArray input) { + if(index >= inputArguments.size() ){ + List oldArgs = inputArguments; + inputArguments = new ArrayList<>(index+1); + inputArguments.addAll(oldArgs); + while(inputArguments.size() <= index) + inputArguments.add(null); + } inputArguments.set(index, input); } @@ -400,12 +407,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } public void setOutputArgument(int index, INDArray output) { - if(index == outputArguments.size()){ - //For example, setOutputArgument(0,arr) on empty list - outputArguments.add(output); - } else { - outputArguments.set(index, output); + while(index >= outputArguments.size()){ + //Resize list, in case we want to specify arrays not in order they are defined + //For example, index 1 on empty list, then index 0 + outputArguments.add(null); } + outputArguments.set(index, output); } @Override @@ -608,6 +615,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } + @Override + public void clearArrays(){ + inputArguments.clear(); + outputArguments.clear(); + } + protected static INDArray[] wrapOrNull(INDArray in){ return in == null ? null : new INDArray[]{in}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java index 3e5644439..ca0a816c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java @@ -167,4 +167,9 @@ public interface Op { * @return the equivalent {@link CustomOp} */ CustomOp toCustomOp(); + + /** + * Clear the input and output INDArrays, if any are set + */ + void clearArrays(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java index 80f8c106d..5adfcbafd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java @@ -25,6 +25,6 @@ public class AdjustContrastV2 extends BaseAdjustContrast { @Override public String tensorflowName() { - return "AdjustContrast"; + return "AdjustContrastV2"; } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java index 0741e512e..cb805a775 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java @@ -245,4 +245,9 @@ public class ScatterUpdate implements CustomOp { public void assertValidForExecution() { } + + @Override + public void clearArrays() { + op.clearArrays(); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java index aed50c987..d8bf3f695 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java @@ -39,13 +39,18 @@ import java.util.*; @NoArgsConstructor public class BiasAdd extends DynamicCustomOp { + protected boolean nchw = true; - public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias) { + public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias, boolean nchw) { super(null, sameDiff, new SDVariable[] {input, bias}, false); + bArguments.clear(); + bArguments.add(nchw); } - public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output){ + public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){ super(new INDArray[]{input, bias}, wrapOrNull(output)); + bArguments.clear(); + bArguments.add(nchw); } @Override @@ -56,7 +61,11 @@ public class BiasAdd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); - + if(attributesForNode.containsKey("data_format")){ + nchw = "NCHW".equalsIgnoreCase(attributesForNode.get("data_format").getS().toStringUtf8()); + } + bArguments.clear(); + bArguments.add(nchw); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java deleted file mode 100644 index 03dc26313..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java +++ /dev/null @@ -1,402 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.controlflow; - -import lombok.*; -import lombok.extern.slf4j.Slf4j; -import onnx.Onnx; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffConditional; -import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.CustomOp; -import org.nd4j.linalg.api.ops.CustomOpDescriptor; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.util.HashUtil; -import org.nd4j.weightinit.impl.ZeroInitScheme; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; - -import java.util.*; - -/** - * Equivalent to tensorflow's conditional op. - * Runs one of 2 {@link SameDiff.SameDiffFunctionDefinition} - * depending on a predicate {@link org.nd4j.autodiff.samediff.SameDiff.SameDiffConditional} - * - * - * @author Adam Gibson - */ -@NoArgsConstructor -@Slf4j -public class If extends DifferentialFunction implements CustomOp { - - @Getter - protected SameDiff loopBodyExecution,predicateExecution,falseBodyExecution; - - - @Getter - protected SameDiffConditional predicate; - @Getter - protected SameDiffFunctionDefinition trueBody,falseBody; - - @Getter - protected String blockName,trueBodyName,falseBodyName; - - @Getter - protected SDVariable[] inputVars; - - @Getter - protected Boolean trueBodyExecuted = null; - - @Getter - protected SDVariable targetBoolean; - - protected SDVariable dummyResult; - - @Getter - @Setter - protected SDVariable[] outputVars; - - public If(If ifStatement) { - this.sameDiff = ifStatement.sameDiff; - this.outputVars = ifStatement.outputVars; - this.falseBodyExecution = ifStatement.falseBodyExecution; - this.trueBodyExecuted = ifStatement.trueBodyExecuted; - this.falseBody = ifStatement.falseBody; - this.trueBodyExecuted = ifStatement.trueBodyExecuted; - this.dummyResult = ifStatement.dummyResult; - this.inputVars = ifStatement.inputVars; - this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme(), DataType.FLOAT, 1); - if(sameDiff.getShapeForVarName(dummyResult.getVarName()) == null) - sameDiff.putShapeForVarName(dummyResult.getVarName(),new long[]{1,1}); - - - - - } - - @Builder - public If(String blockName, - SameDiff parent, - SDVariable[] inputVars, - SameDiffFunctionDefinition conditionBody, - SameDiffConditional predicate, - SameDiffFunctionDefinition trueBody, - SameDiffFunctionDefinition falseBody) { - - this.sameDiff = parent; - parent.putOpForId(getOwnName(),this); - this.inputVars = inputVars; - this.predicate = predicate; - - parent.addArgsFor(inputVars,this); - this.trueBody = trueBody; - this.falseBody = falseBody; - this.blockName = blockName; - //need to add the op to the list of ops to be executed when running backwards - this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1); - parent.addOutgoingFor(new SDVariable[]{dummyResult},this); - - //create a samediff sub graph for running just the execution - //return a reference to the loop for referencing during actual execution - SameDiff sameDiff = SameDiff.create(); - //store the reference to the result array and the same diff execution instance - this.targetBoolean = predicate.eval(sameDiff,conditionBody, inputVars); - this.predicateExecution = sameDiff; - //store references to the loop body - String trueBodyName = "true-body-" + UUID.randomUUID().toString(); - this.trueBodyName = trueBodyName; - - String falseBodyName = "false-body-" + UUID.randomUUID().toString(); - this.falseBodyName = trueBodyName; - - //running define function will setup a proper same diff instance - this.loopBodyExecution = parent.defineFunction(trueBodyName,trueBody,inputVars); - this.falseBodyExecution = parent.defineFunction(falseBodyName,falseBody,inputVars); - parent.defineFunction(blockName,conditionBody,inputVars); - parent.putSubFunction("predicate-eval-body-" + UUID.randomUUID().toString(),sameDiff); - //get a reference to the actual loop body - this.loopBodyExecution = parent.getFunction(trueBodyName); - } - - - /** - * Toggle whether the true body was executed - * or the false body - * @param trueBodyExecuted - */ - public void exectedTrueOrFalse(boolean trueBodyExecuted) { - if(trueBodyExecuted) - this.trueBodyExecuted = true; - else - this.trueBodyExecuted = false; - } - - - - @Override - public SDVariable[] outputVariables(String baseName) { - return new SDVariable[]{dummyResult}; - } - - @Override - public List doDiff(List f1) { - List ret = new ArrayList<>(); - ret.addAll(Arrays.asList(new IfDerivative(this).outputVariables())); - return ret; - } - - @Override - public String toString() { - return opName(); - } - - @Override - public String opName() { - return "if"; - } - - @Override - public long opHash() { - return HashUtil.getLongHash(opName()); - } - - @Override - public boolean isInplaceCall() { - return false; - } - - @Override - public INDArray[] outputArguments() { - return new INDArray[0]; - } - - @Override - public INDArray[] inputArguments() { - return new INDArray[0]; - } - - @Override - public long[] iArgs() { - return new long[0]; - } - - @Override - public double[] tArgs() { - return new double[0]; - } - - @Override - public boolean[] bArgs() { - return new boolean[0]; - } - - @Override - public void addIArgument(int... arg) { - - } - - @Override - public void addIArgument(long... arg) { - - } - - @Override - public void addBArgument(boolean... arg) { - - } - - @Override - public void removeIArgument(Integer arg) { - - } - - @Override - public Boolean getBArgument(int index) { - return null; - } - - @Override - public Long getIArgument(int index) { - return null; - } - - @Override - public int numIArguments() { - return 0; - } - - @Override - public void addTArgument(double... arg) { - - } - - @Override - public void removeTArgument(Double arg) { - - } - - @Override - public Double getTArgument(int index) { - return null; - } - - @Override - public int numTArguments() { - return 0; - } - - @Override - public int numBArguments() { - return 0; - } - - @Override - public void addInputArgument(INDArray... arg) { - - } - - @Override - public void removeInputArgument(INDArray arg) { - - } - - @Override - public INDArray getInputArgument(int index) { - return null; - } - - @Override - public int numInputArguments() { - return 0; - } - - @Override - public void addOutputArgument(INDArray... arg) { - - } - - @Override - public void removeOutputArgument(INDArray arg) { - - } - - @Override - public INDArray getOutputArgument(int index) { - return null; - } - - @Override - public int numOutputArguments() { - return 0; - } - - @Override - public Op.Type opType() { - return Op.Type.CONDITIONAL; - } - - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - //cond is only part of while loops - if(nodeDef.getName().contains("/cond/")) - return; - //usually should be a merge node for a conditional - val ifNodes = TFGraphMapper.getInstance().nodesForIf(nodeDef,graph); - - - val trueScopeGraphDefBuilder = GraphDef.newBuilder(); - for(val node : ifNodes.getTrueNodes()) { - trueScopeGraphDefBuilder.addNode(node); - } - - - val trueScope = TFGraphMapper.getInstance().importGraph(trueScopeGraphDefBuilder.build()); - - - val falseScopeGraphDefBuilder = GraphDef.newBuilder(); - for(val node : ifNodes.getFalseNodes()) { - falseScopeGraphDefBuilder.addNode(node); - - } - - val falseScope = TFGraphMapper.getInstance().importGraph(falseScopeGraphDefBuilder.build()); - - - val condScopeGraphDefBuilder = GraphDef.newBuilder(); - for(val node : ifNodes.getCondNodes()) { - condScopeGraphDefBuilder.addNode(node); - - } - - - val condScope = TFGraphMapper.getInstance().importGraph(condScopeGraphDefBuilder.build()); - - - - initWith.putSubFunction(ifNodes.getTrueBodyScopeName(),trueScope); - initWith.putSubFunction(ifNodes.getFalseBodyScopeName(),falseScope); - initWith.putSubFunction(ifNodes.getConditionBodyScopeName(),condScope); - - this.loopBodyExecution = trueScope; - this.falseBodyExecution = falseScope; - this.predicateExecution = condScope; - } - - - @Override - public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - - } - - - - @Override - public List calculateOutputShape() { - return Arrays.asList(LongShapeDescriptor.fromShape(new long[0], DataType.BOOL)); - } - - @Override - public CustomOpDescriptor getDescriptor() { - return null; - } - - @Override - public void assertValidForExecution() { - - } - - - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("This operation has no TF counterpart"); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfDerivative.java deleted file mode 100644 index 77b2eafa1..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfDerivative.java +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.controlflow; - -import lombok.NoArgsConstructor; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffConditional; -import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; - -import java.util.List; - -@NoArgsConstructor -public class IfDerivative extends If { - - private If ifDelegate; - - public IfDerivative(If ifBlock) { - super(ifBlock); - this.ifDelegate = ifBlock; - } - - @Override - public Boolean getTrueBodyExecuted() { - return ifDelegate.trueBodyExecuted; - } - - - @Override - public SameDiffFunctionDefinition getFalseBody() { - return ifDelegate.falseBody; - } - - @Override - public SameDiff getFalseBodyExecution() { - return ifDelegate.falseBodyExecution; - } - - @Override - public String getBlockName() { - return ifDelegate.blockName; - } - - @Override - public String getFalseBodyName() { - return ifDelegate.falseBodyName; - } - - @Override - public SameDiff getLoopBodyExecution() { - return ifDelegate.loopBodyExecution; - } - - @Override - public SameDiffConditional getPredicate() { - return ifDelegate.getPredicate(); - } - - @Override - public SameDiff getPredicateExecution() { - return ifDelegate.predicateExecution; - } - - @Override - public List calculateOutputShape() { - return super.calculateOutputShape(); - } - - @Override - public String opName() { - return "if_bp"; - } - - @Override - public List diff(List i_v1) { - throw new UnsupportedOperationException("Unable to take the derivative of the derivative for if"); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfImportState.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfImportState.java deleted file mode 100644 index e81016426..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfImportState.java +++ /dev/null @@ -1,32 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.controlflow; - -import lombok.Builder; -import lombok.Data; -import org.tensorflow.framework.NodeDef; - -import java.util.List; - -@Builder -@Data -public class IfImportState { - private List condNodes; - private List trueNodes; - private List falseNodes; - private String falseBodyScopeName,trueBodyScopeName,conditionBodyScopeName; -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/Select.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/Select.java index 2a3403ae8..5fdebd03d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/Select.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/Select.java @@ -55,7 +55,7 @@ public class Select extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java deleted file mode 100644 index e26b0ea5f..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java +++ /dev/null @@ -1,660 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.controlflow; - -import lombok.*; -import lombok.extern.slf4j.Slf4j; -import onnx.Onnx; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffConditional; -import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.converters.DifferentialFunctionClassHolder; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.CustomOp; -import org.nd4j.linalg.api.ops.CustomOpDescriptor; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.exception.ND4JIllegalArgumentException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.weightinit.impl.ZeroInitScheme; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; - -import java.util.*; -import java.util.concurrent.atomic.AtomicInteger; - -/** - * Equivalent to tensorflow's while loop - * Takes in: - * loopVars - * loop body - * condition - * - * runs loop till condition is false. - * @author Adam Gibson - */ -@NoArgsConstructor -@Slf4j -public class While extends DifferentialFunction implements CustomOp { - private AtomicInteger startPosition; - - - - @Getter - protected SameDiff loopBodyExecution,predicateExecution; - - - @Getter - protected SameDiffConditional predicate; - @Getter - protected SameDiffFunctionDefinition trueBody; - - @Getter - protected String blockName,trueBodyName; - - @Getter - protected SDVariable[] inputVars; - - - @Getter - protected SDVariable targetBoolean; - - protected SDVariable dummyResult; - - @Getter - @Setter - protected SDVariable[] outputVars; - - @Getter - protected int numLooped = 0; - - /** - * Mainly meant for tensorflow import. - * This allows {@link #initFromTensorFlow(NodeDef, SameDiff, Map, GraphDef)} - * to continue from a parent while loop - * using the same graph - * @param startPosition the start position for the import scan - */ - public While(AtomicInteger startPosition) { - this.startPosition = startPosition; - } - - public While(While whileStatement) { - this.sameDiff = whileStatement.sameDiff; - this.outputVars = whileStatement.outputVars; - this.loopBodyExecution = whileStatement.loopBodyExecution; - this.numLooped = whileStatement.numLooped; - this.dummyResult = whileStatement.dummyResult; - this.predicate = whileStatement.predicate; - this.predicateExecution = whileStatement.predicateExecution; - this.inputVars = whileStatement.inputVars; - this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1); - } - - - - @Builder - public While(String blockName, - SameDiff parent, - SDVariable[] inputVars, - SameDiffConditional predicate, - SameDiffFunctionDefinition condition, - SameDiffFunctionDefinition trueBody) { - init(blockName,parent,inputVars,predicate,condition,trueBody); - } - - - private void init(String blockName, - SameDiff parent, - SDVariable[] inputVars, - SameDiffConditional predicate, - SameDiffFunctionDefinition condition, - SameDiffFunctionDefinition trueBody) { - this.sameDiff = parent; - this.inputVars = inputVars; - this.predicate = predicate; - this.trueBody = trueBody; - this.blockName = blockName; - this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1); - parent.putOpForId(getOwnName(),this); - - parent.addArgsFor(inputVars,this); - parent.addOutgoingFor(new SDVariable[]{dummyResult},this); - - - //create a samediff sub graph for running just the execution - //return a reference to the loop for referencing during actual execution - SameDiff sameDiff = SameDiff.create(); - //store the reference to the result array and the same diff execution instance - this.targetBoolean = predicate.eval(sameDiff,condition, inputVars); - this.predicateExecution = sameDiff; - //store references to the loop body - String trueBodyName = "true-body-" + UUID.randomUUID().toString(); - this.trueBodyName = trueBodyName; - //running define function will setup a proper same diff instance - parent.defineFunction(trueBodyName,trueBody,inputVars); - parent.defineFunction(blockName,condition,inputVars); - parent.putSubFunction("predicate-eval-body",sameDiff); - //get a reference to the actual loop body - this.loopBodyExecution = parent.getFunction(trueBodyName); - - } - - - @Override - public SDVariable[] outputVariables(String baseName) { - return new SDVariable[]{dummyResult}; - } - - @Override - public List doDiff(List f1) { - List ret = new ArrayList<>(); - ret.addAll(Arrays.asList(new WhileDerivative(this).outputVariables())); - return ret; - } - - - - /** - * Increments the loop counter. - * This should be called when the loop - * actually executes. - */ - public void incrementLoopCounter() { - numLooped++; - } - - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - doImport(nodeDef,initWith,attributesForNode,graph,new LinkedHashSet(),new AtomicInteger(0)); - } - - - private void doImport(NodeDef nodeDef,SameDiff initWith,Map attributesForNode,GraphDef graph,Set skipSet,AtomicInteger currIndex) { - val uniqueId = java.util.UUID.randomUUID().toString(); - skipSet.add(nodeDef.getName()); - val scopeCondition = SameDiff.create(); - val scopeLoop = SameDiff.create(); - initWith.putSubFunction("condition-" + uniqueId,scopeCondition); - initWith.putSubFunction("loopbody-" + uniqueId,scopeLoop); - this.loopBodyExecution = scopeLoop; - this.predicateExecution = scopeCondition; - this.startPosition = currIndex; - - log.info("Adding 2 new scopes for WHILE {}"); - - - val nodes = graph.getNodeList(); - - /** - * Plan is simple: - * 1) we read all declarations of variables used within loop - * 2) we set up conditional scope - * 3) we set up body scope - * 4) ??? - * 5) PROFIT! - */ - - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (!tfNode.getOp().equalsIgnoreCase("enter")) { - //skipSet.add(tfNode.getName()); - break; - } - -// if (skipSet.contains(tfNode.getName())) -// continue; - - skipSet.add(tfNode.getName()); - - val vars = new SDVariable[tfNode.getInputCount()]; - for (int e = 0; e < tfNode.getInputCount(); e++) { - val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e)); - vars[e] = initWith.getVariable(input) == null ? initWith.var(input, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(input); - scopeCondition.var(vars[e]); - scopeLoop.var(vars[e]); - } - - this.inputVars = vars; - } - - - // now we're skipping Merge step, since we've already captured variables at Enter step - int mergedCnt = 0; - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (!tfNode.getOp().equalsIgnoreCase("merge")) { - scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor) null,new ZeroInitScheme()); - break; - } - - skipSet.add(tfNode.getName()); - val var = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor)null,new ZeroInitScheme()); - scopeCondition.var(var); - initWith.var(var); - mergedCnt++; - } - - - // now, we're adding conditional scope - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - // we're parsing up to condition - if (tfNode.getOp().equalsIgnoreCase("LoopCond")) { - skipSet.add(tfNode.getName()); - currIndex.incrementAndGet(); - break; - } - - boolean isConst = tfNode.getOp().equalsIgnoreCase("const"); - boolean isVar = tfNode.getOp().startsWith("VariableV"); - boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder"); - - - if (isConst || isVar || isPlaceholder) { - val var = scopeCondition.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme()); - scopeLoop.var(var); - initWith.var(var); - log.info("Adding condition var [{}]", var.getVarName()); - - } - else if(!skipSet.contains(tfNode.getName())) { - val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName()); - func.initFromTensorFlow(tfNode,scopeCondition,nodeDef.getAttrMap(),graph); - func.setSameDiff(scopeLoop); - - } - - skipSet.add(tfNode.getName()); - } - - - - // time to skip some Switch calls - int switchCnt = 0; - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - // we're parsing up to condition - if (!tfNode.getOp().equalsIgnoreCase("Switch")) - break; - - switchCnt++; - skipSet.add(tfNode.getName()); - } - - // now we're parsing Identity step - int identityCnt = 0; - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - - if (!tfNode.getOp().equalsIgnoreCase("Identity")) { - break; - } - - - val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName()); - func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph); - func.setSameDiff(scopeLoop); - - - val variables = new SDVariable[tfNode.getInputCount()]; - for(int i = 0; i < tfNode.getInputCount(); i++) { - val testVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i))); - if(testVar == null) { - variables[i] = initWith.var(tfNode.getInput(i), (LongShapeDescriptor) null,new ZeroInitScheme()); - scopeCondition.var(variables[i]); - scopeLoop.var(variables[i]); - continue; - } - else { - - variables[i] = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i))); - scopeCondition.var(variables[i]); - scopeLoop.var(variables[i]); - } - - } - - scopeLoop.addArgsFor(variables,func); - skipSet.add(tfNode.getName()); - } - - - // parsing body scope - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (skipSet.contains(tfNode.getName())) { - log.info("Skipping: {}", tfNode.getName()); - continue; - } - - if (tfNode.getOp().equalsIgnoreCase("NextIteration")) { -// skipSet.add(tfNode.getName()); - break; - } - - if (skipSet.contains(tfNode.getName())) { - log.info("Skipping: {}", tfNode.getName()); - continue; - } - - - - boolean isConst = tfNode.getOp().equalsIgnoreCase("const"); - boolean isVar = tfNode.getOp().startsWith("VariableV"); - boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder"); - - - if (isConst || isVar || isPlaceholder) { - val var = scopeLoop.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme()); - log.info("Adding body var [{}]",var.getVarName()); - - } else { - log.info("starting on [{}]: {}", tfNode.getName(), tfNode.getOp()); - - if (tfNode.getOp().equalsIgnoreCase("enter")) { - log.info("NEW LOOP ----------------------------------------"); - val func = new While(currIndex); - func.doImport(nodeDef,initWith,attributesForNode,graph,skipSet,currIndex); - func.setSameDiff(initWith); - log.info("END LOOP ----------------------------------------"); - } else { - val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName()); - - func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph); - - - func.setSameDiff(scopeCondition); - - val variables = new SDVariable[tfNode.getInputCount()]; - for(int i = 0; i < tfNode.getInputCount(); i++) { - val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)); - variables[i] = scopeCondition.getVariable(name); - if(variables[i] == null) { - if(scopeLoop.getVariable(name) == null) - variables[i] = scopeCondition.var(initWith.getVariable(name)); - else if(scopeLoop.getVariable(name) != null) - variables[i] = scopeLoop.getVariable(name); - else - variables[i] = scopeLoop.var(name, Nd4j.scalar(1.0)); - } - } - - scopeLoop.addArgsFor(variables,func); - - - } - } - - skipSet.add(tfNode.getName()); - } - - - val returnInputs = new ArrayList(); - val returnOutputs = new ArrayList(); - - // mapping NextIterations, to Return op - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (!tfNode.getOp().equalsIgnoreCase("NextIteration")) - break; - - skipSet.add(tfNode.getName()); - - val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName()); - val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ; - returnInputs.add(input); - } - - - this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]); - this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]); - initWith.addArgsFor(inputVars,this); - initWith.addOutgoingFor(outputVars,this); - - // we should also map While/Exit to libnd4j while - int exitCnt = 0; - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (!tfNode.getOp().equalsIgnoreCase("Exit")) { - //skipSet.add(tfNode.getName()); - break; - } - - skipSet.add(tfNode.getName()); - val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName()); - val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ; - } - - - //the output of the condition should always be a singular scalar - //this is a safe assumption - val conditionVars = scopeCondition.ops(); - if(conditionVars.length < 1) { - throw new ND4JIllegalArgumentException("No functions found!"); - } - this.targetBoolean = conditionVars[conditionVars.length - 1].outputVariables()[0]; - - log.info("-------------------------------------------"); - - } - - @Override - public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - - } - - - @Override - public String toString() { - return opName(); - } - - @Override - public String opName() { - return "while"; - } - - @Override - public long opHash() { - return opName().hashCode(); - } - - @Override - public boolean isInplaceCall() { - return false; - } - - @Override - public INDArray[] outputArguments() { - return new INDArray[0]; - } - - @Override - public INDArray[] inputArguments() { - return new INDArray[0]; - } - - @Override - public long[] iArgs() { - return new long[0]; - } - - @Override - public double[] tArgs() { - return new double[0]; - } - - @Override - public void addIArgument(int... arg) { - - } - - @Override - public void addIArgument(long... arg) { - - } - - @Override - public void removeIArgument(Integer arg) { - - } - - @Override - public Long getIArgument(int index) { - return null; - } - - @Override - public int numIArguments() { - return 0; - } - - @Override - public void addTArgument(double... arg) { - - } - - @Override - public void removeTArgument(Double arg) { - - } - - @Override - public Double getTArgument(int index) { - return null; - } - - @Override - public int numTArguments() { - return 0; - } - - @Override - public int numBArguments() { - return 0; - } - - @Override - public void addInputArgument(INDArray... arg) { - - } - - @Override - public void removeInputArgument(INDArray arg) { - - } - - @Override - public boolean[] bArgs() { - return new boolean[0]; - } - - @Override - public void addBArgument(boolean... arg) { - - } - - @Override - public Boolean getBArgument(int index) { - return null; - } - - @Override - public INDArray getInputArgument(int index) { - return null; - } - - @Override - public int numInputArguments() { - return 0; - } - - @Override - public void addOutputArgument(INDArray... arg) { - - } - - @Override - public void removeOutputArgument(INDArray arg) { - - } - - @Override - public INDArray getOutputArgument(int index) { - return null; - } - - @Override - public int numOutputArguments() { - return 0; - } - @Override - public List calculateOutputShape() { - List ret = new ArrayList<>(); - for(SDVariable var : args()) { - ret.add(sameDiff.getShapeDescriptorForVarName(var.getVarName())); - } - return ret; - } - - @Override - public CustomOpDescriptor getDescriptor() { - return CustomOpDescriptor.builder().build(); - } - - @Override - public void assertValidForExecution() { - - } - - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No *singular (eg: use tensorflowNames() found for this op " + opName()); - } - - @Override - public String[] tensorflowNames() { - throw new NoOpNameFoundException("This operation has no TF counterpart"); - } - - - @Override - public Op.Type opType() { - return Op.Type.LOOP; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/WhileDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/WhileDerivative.java deleted file mode 100644 index d9aaf2af0..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/WhileDerivative.java +++ /dev/null @@ -1,96 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.controlflow; - -import lombok.NoArgsConstructor; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffConditional; -import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ops.Op; - -/** - * While loop derivative - * @author Adam Gibson - */ -@NoArgsConstructor -public class WhileDerivative extends While { - private While delegate; - - public WhileDerivative(While delegate) { - super(delegate); - this.delegate = delegate; - } - - - - @Override - public SameDiffFunctionDefinition getTrueBody() { - return delegate.trueBody; - } - - @Override - public String getTrueBodyName() { - return delegate.getTrueBodyName(); - } - - @Override - public SameDiffConditional getPredicate() { - return delegate.getPredicate(); - } - - @Override - public SameDiff getPredicateExecution() { - return delegate.getPredicateExecution(); - } - - @Override - public SDVariable[] getInputVars() { - return delegate.getInputVars(); - } - - @Override - public String getBlockName() { - return delegate.getBlockName(); - } - - @Override - public SameDiff getLoopBodyExecution() { - return delegate.getLoopBodyExecution(); - } - - @Override - public int getNumLooped() { - return delegate.getNumLooped(); - } - - @Override - public String opName() { - return "while_bp"; - } - - @Override - public Op.Type opType() { - return Op.Type.CONDITIONAL; - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow name for while backprop"); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java index 3f56096a2..1bb451bf1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java @@ -55,7 +55,7 @@ public abstract class BaseCompatOp extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java index 4f5d11b38..769f7c509 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java @@ -32,9 +32,11 @@ import java.util.List; import java.util.Map; public class LoopCond extends BaseCompatOp { + public static final String OP_NAME = "loop_cond"; + @Override public String opName() { - return "loop_cond"; + return OP_NAME; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java index 4ede302dd..0fee6c238 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java @@ -74,8 +74,6 @@ public class CropAndResize extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - String method = attributesForNode.get("method").getS().toStringUtf8(); if(method.equalsIgnoreCase("nearest")){ this.method = Method.NEAREST; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java index 62194c044..8922df9e5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java @@ -120,4 +120,10 @@ public class ExtractImagePatches extends DynamicCustomOp { //TF includes redundant leading and training 1s for kSizes, strides, rates (positions 0/3) return new int[]{(int)ilist.getI(1), (int)ilist.getI(2)}; } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java index 5ae8f85ea..be6eb3730 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java @@ -74,7 +74,7 @@ public class ResizeBilinear extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); this.alignCorners = attributesForNode.get("align_corners").getB(); addArgs(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java index ecb48f922..ea339ae2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java @@ -50,7 +50,7 @@ public class ResizeNearestNeighbor extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java index bad975cb5..a8c50abdf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java @@ -26,8 +26,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.base.Preconditions; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,7 +39,6 @@ import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.lang.reflect.Field; import java.util.*; @@ -106,7 +103,7 @@ public class BatchNorm extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); //Switch order: TF uses [input, gamma, beta, mean, variance]; libnd4j expects [input, mean, variance, gamma, beta] SameDiffOp op = initWith.getOps().get(this.getOwnName()); List list = op.getInputsToOp(); @@ -140,8 +137,7 @@ public class BatchNorm extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); - addArgs(); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java index 2fc814fb3..852c865f7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java @@ -21,33 +21,20 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import lombok.val; -import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.converters.DifferentialFunctionClassHolder; -import org.nd4j.imports.descriptors.properties.AttributeAdapter; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter; -import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter; -import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater; -import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.util.ArrayUtil; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; import java.lang.reflect.Field; -import java.util.*; +import java.util.Collections; +import java.util.List; +import java.util.Map; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 5e077e3fc..3794469ae 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -31,7 +31,6 @@ import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.adapters.*; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -122,7 +121,7 @@ public class Conv2D extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } @@ -138,8 +137,7 @@ public class Conv2D extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); - addArgs(); + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java index 8c4e40e8a..665e7dd99 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java @@ -251,7 +251,7 @@ public class Conv3D extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index c69292dd9..dc92d826d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -198,7 +198,7 @@ public class DeConv2D extends DynamicCustomOp { val args = args(); INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr(); if (arr == null) { - arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph); + arr = TFGraphMapper.getNDArrayFromTensor(nodeDef); // TODO: arguable. it might be easier to permute weights once //arr = (arr.permute(3, 2, 0, 1).dup('c')); val varForOp = initWith.getVariable(args[1].getVarName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java index bc4f996b1..dfabc89dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java @@ -214,7 +214,7 @@ public class DeConv2DTF extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } @@ -240,9 +240,9 @@ public class DeConv2DTF extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List inputDataTypes){ + public List calculateOutputDataTypes(List inputDataTypes){ //inShape, weights, input int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Collections.singletonList(inputDataTypes.get(0)); + return Collections.singletonList(inputDataTypes.get(2)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index 077f6a64b..6a3c8854f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -160,7 +160,7 @@ public class DeConv3D extends DynamicCustomOp { val args = args(); INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr(); if (arr == null) { - arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph); + arr = TFGraphMapper.getNDArrayFromTensor(nodeDef); val varForOp = initWith.getVariable(args[1].getVarName()); if (arr != null) initWith.associateArrayWithVariable(arr, varForOp); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java index 6715f742a..704f8bdd4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java @@ -77,7 +77,7 @@ public class DepthToSpace extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); boolean isNHWC = dataFormat.equals("NHWC"); addIArgument(blockSize, isNHWC ? 1 : 0); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index ec2bb1d3f..4b10909a0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -29,14 +29,15 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.descriptors.properties.adapters.*; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; +import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter; +import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter; +import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater; +import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -136,7 +137,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); /* @@ -162,8 +163,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); - addArgs(); + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java index 90bdcdb45..e591c9f1c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java @@ -75,7 +75,7 @@ public class SpaceToDepth extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC"); addIArgument(blockSize, isNHWC ? 1 : 0); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java index 3b12187e3..21756d99b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java @@ -64,7 +64,7 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java index 9c385c425..15f556c64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java @@ -55,7 +55,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); //Switch order: TF uses [logits, labels]; libnd4j expects [labels, logits] SameDiffOp op = initWith.getOps().get(this.getOwnName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java index 80b767676..7afdf5166 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java @@ -64,7 +64,7 @@ public class Moments extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java index 09f4ac2f4..36b4600c9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java @@ -60,7 +60,7 @@ public class NormalizeMoments extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java index ad19598eb..6c8aa5901 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java @@ -63,7 +63,7 @@ public class ScatterAdd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java index ea7ef3da7..4e7563e4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java @@ -86,7 +86,7 @@ public class ScatterDiv extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java index 33f8db980..65162aad3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java @@ -60,7 +60,7 @@ public class ScatterMax extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java index 00322b259..8d8fe4e33 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java @@ -60,7 +60,7 @@ public class ScatterMin extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java index 1db426364..2790667cd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java @@ -62,7 +62,7 @@ public class ScatterMul extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java index a589fa1ae..a72801760 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java @@ -67,7 +67,7 @@ public class ScatterNd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { @@ -80,8 +80,8 @@ public class ScatterNd extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), inputDataTypes); + public List calculateOutputDataTypes(List inputDataTypes){ //Indices, updates, shape + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(1)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java index 7dd2b9462..c79ec058d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java @@ -66,7 +66,7 @@ public class ScatterNdAdd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java index 42c539f58..8efc6717f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java @@ -66,7 +66,7 @@ public class ScatterNdSub extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java index aeb3c9872..bf95b448d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java @@ -66,7 +66,7 @@ public class ScatterNdUpdate extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java index 375d5bc6b..382806779 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java @@ -79,7 +79,7 @@ public class ScatterSub extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java index ccfc541de..980ae7f8c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java @@ -73,8 +73,6 @@ public class ScatterUpdate extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { bArguments.add(true); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index 5c6beb945..d40a3a334 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -151,6 +151,7 @@ public class Concat extends DynamicCustomOp { removeInputArgument(inputArgs[inputArguments().length - 1]); } + //TODO Fix this: https://github.com/eclipse/deeplearning4j/issues/8285 sameDiff.removeArgFromOp(input,this); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java index 5c50b983d..a13a03184 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java @@ -69,8 +69,8 @@ public class ExpandDims extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - val targetNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); - val dimArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", targetNode, graph); + val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); + val dimArr = TFGraphMapper.getNDArrayFromTensor(targetNode); if (dimArr != null) { int axis = dimArr.data().asInt()[0]; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java index 5613cc85f..fd6ec5240 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java @@ -22,13 +22,9 @@ import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -73,12 +69,12 @@ public class Gather extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index cfe4fe8be..b8ef51d57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -17,26 +17,13 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; -import lombok.val; -import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; /** * GatherND op diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java index 2d5dcc63c..841fec7b0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java @@ -88,7 +88,7 @@ public class OneHot extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); if(attributesForNode.containsKey("T")) { outputType = TFGraphMapper.convertType(attributesForNode.get("T").getType()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java index 1856e6804..3a1605d8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java @@ -64,7 +64,7 @@ public class ParallelStack extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java index 568b14a44..c05df441c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java @@ -50,21 +50,6 @@ public class Rank extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {input}, inPlace); } - - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - val name = TFGraphMapper.getInstance().getNodeName(nodeDef.getName()); - val input = initWith.getVariable(name); - val outputVertex = input.getVarName(); - if (!initWith.isPlaceHolder(input.getVarName()) && initWith.shapeAlreadyExistsForVarName(outputVertex)) { - val inputShape = initWith.getShapeForVarName(input.getVarName()); - val resultLength = Nd4j.scalar(inputShape.length); - val thisResultId = outputVertex; - initWith.setArrayForVariable(thisResultId, resultLength); - initWith.putShapeForVarName(thisResultId, new long[]{1, 1}); - } - } - @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java index af8940bf4..4bf920da9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java @@ -101,7 +101,7 @@ public class Repeat extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addIArgument(jaxis); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index b30bacc22..44d9b79fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -21,20 +21,18 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.*; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** * Reshape function @@ -70,32 +68,7 @@ public class Reshape extends DynamicCustomOp { if (!nodeDef.containsAttr("TShape") && nodeDef.getInputCount() == 1) { this.shape = new long[]{}; return; - } else if (nodeDef.getInputCount() > 1) { - val shapeNode = nodeDef.getInput(1); - NodeDef shapeNodeInGraph = null; - for (int i = 0; i < graph.getNodeCount(); i++) { - if (graph.getNode(i).getName().equals(shapeNode)) { - shapeNodeInGraph = graph.getNode(i); - - } - } - - val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", shapeNodeInGraph, graph); - if (arr != null && arr.isEmpty()) { - // special case: empty array - this.shape = new long[0]; - - } else if (arr != null) { - this.shape = arr.data().asLong(); - //all TF is c - if (!ArrayUtil.containsAnyNegative(this.shape)) - addIArgument(this.shape); - else { - arrName = nodeDef.getName(); - } - - } - } else { + } else if(nodeDef.getInputCount() == 1){ val shape = nodeDef.getAttrOrThrow("Tshape"); if (!shape.hasShape()) { val shapeRet = new long[2]; @@ -127,8 +100,7 @@ public class Reshape extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - val shape = new OnnxGraphMapper().getShape(node); - this.shape = shape; + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index a2f6bd208..67454e231 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -65,13 +65,13 @@ public class SequenceMask extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - val targetNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); - val maxlen = TFGraphMapper.getInstance().getNDArrayFromTensor("value", targetNode, graph); + val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); + val maxlen = TFGraphMapper.getNDArrayFromTensor(targetNode); if (maxlen == null){ // No 2nd input this.is_static_maxlen = true; } - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (is_static_maxlen) { addIArgument(this.maxLen); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Split.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Split.java index f11c10c1c..685623d32 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Split.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Split.java @@ -54,7 +54,7 @@ public class Split extends DynamicCustomOp { this.numSplit = numSplits; addIArgument(numSplits); - val splitDim = TFGraphMapper.getInstance().getArrayFrom(TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); + val splitDim = TFGraphMapper.getArrayFrom(TFGraphMapper.getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); if(splitDim != null) { this.splitDim = splitDim.getInt(0); addIArgument(splitDim.getInt(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SplitV.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SplitV.java index 134f0dbe3..2407bc0f1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SplitV.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SplitV.java @@ -49,7 +49,7 @@ public class SplitV extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - val splitDim = TFGraphMapper.getInstance().getArrayFrom(TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); + val splitDim = TFGraphMapper.getArrayFrom(TFGraphMapper.getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); if(splitDim != null) { this.splitDim = splitDim.getInt(0); addIArgument(splitDim.getInt(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index 6cd09f9bd..d2bf9d71b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -88,7 +88,7 @@ public class Stack extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 782c70859..41224fbb7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -114,7 +114,7 @@ public class Transpose extends DynamicCustomOp { } - INDArray permuteArrayOp = TFGraphMapper.getInstance().getNDArrayFromTensor("value", permuteDimsNode, graph); + INDArray permuteArrayOp = TFGraphMapper.getNDArrayFromTensor(permuteDimsNode); if (permuteArrayOp != null) { this.permuteDims = permuteArrayOp.data().asInt(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java index b027750fc..a9e67f9f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java @@ -47,8 +47,8 @@ public abstract class BaseTensorOp extends DynamicCustomOp { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { val inputOne = nodeDef.getInput(1); val varFor = initWith.getVariable(inputOne); - val nodeWithIndex = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,inputOne); - val var = TFGraphMapper.getInstance().getArrayFrom(nodeWithIndex,graph); + val nodeWithIndex = TFGraphMapper.getNodeWithNameFromGraph(graph,inputOne); + val var = TFGraphMapper.getArrayFrom(nodeWithIndex,graph); if(var != null) { val idx = var.getInt(0); addIArgument(idx); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArray.java index b22434a71..4ecdf947d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArray.java @@ -70,7 +70,7 @@ public class TensorArray extends BaseTensorOp { } } - val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value",iddNode,graph); + val arr = TFGraphMapper.getNDArrayFromTensor(iddNode); if (arr != null) { int idx = arr.getInt(0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java index 4d321d920..9a935aa8f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java @@ -18,12 +18,15 @@ package org.nd4j.linalg.api.ops.impl.transforms; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -46,11 +49,17 @@ public class Cholesky extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } @Override public List doDiff(List f1) { throw new UnsupportedOperationException(); } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/NthElement.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/NthElement.java index 958df5579..fcf6390cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/NthElement.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/NthElement.java @@ -18,12 +18,15 @@ package org.nd4j.linalg.api.ops.impl.transforms; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -47,7 +50,7 @@ public class NthElement extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); this.reverse = attributesForNode.get("reverse").getB(); addArgs(); @@ -70,4 +73,10 @@ public class NthElement extends DynamicCustomOp { public List doDiff(List f1) { throw new UnsupportedOperationException(); } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ //Input and number + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index 72d2823b5..7ea1bb38b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -99,8 +99,8 @@ public class Pad extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2), - "Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() >= 1 && inputDataTypes.size() <= 3), + "Expected 1-3 input datatypes for %s, got %s", getClass(), inputDataTypes); //input, padding, pad value return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java index 9c04aeb12..3874c040b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java @@ -120,7 +120,7 @@ public class CumProd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } @@ -143,7 +143,8 @@ public class CumProd extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List dataTypes){ - Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), dataTypes); + Preconditions.checkState(dataTypes != null && (dataTypes.size() == 1 || dataTypes.size() == 2), + "Expected 1 or 2 input datatype for %s, got %s", getClass(), dataTypes); //2nd optional input - axis return Collections.singletonList(dataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java index b8c7d5c51..6720b5a75 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java @@ -122,7 +122,7 @@ public class CumSum extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } @@ -144,7 +144,8 @@ public class CumSum extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List dataTypes){ - Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), dataTypes); + Preconditions.checkState(dataTypes != null && (dataTypes.size() == 1 || dataTypes.size() == 2), + "Expected 1 or 2 input datatype for %s, got %s", getClass(), dataTypes); //2nd optional input - axis return Collections.singletonList(dataTypes.get(0)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java index 79e793174..fc909261b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java @@ -19,12 +19,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.adapters.*; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.util.ArrayUtil; @@ -32,9 +34,7 @@ import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Map; +import java.util.*; /** * Dilation2D op wrapper @@ -90,7 +90,7 @@ public class Dilation2D extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph); addArgs(); } @@ -185,4 +185,11 @@ public class Dilation2D extends DynamicCustomOp { public String tensorflowName() { return "Dilation2D"; } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ //Input and weights, optional rates/strides + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() >= 2 && inputDataTypes.size() <= 4, + "Expected 2 to 4 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index 06d52f777..8581c51fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -74,7 +74,7 @@ public class DynamicPartition extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java index af4097870..a5ffbced5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java @@ -54,7 +54,6 @@ public class Fill extends DynamicCustomOp { public Fill(SameDiff sameDiff, SDVariable shape, DataType outputDataType, double value) { super(null,sameDiff, new SDVariable[] {shape}, false); this.value = value; - val shp = shape.getArr(); this.outputDataType = outputDataType; addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InTopK.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InTopK.java index 7a69306ab..28fe7c305 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InTopK.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InTopK.java @@ -74,7 +74,7 @@ public class InTopK extends DynamicCustomOp { } Preconditions.checkState(kNode != null, "Could not find 'k' parameter node for op: %s", thisName); - INDArray arr = TFGraphMapper.getInstance().getNDArrayFromTensor(inputName, kNode, graph); + INDArray arr = TFGraphMapper.getNDArrayFromTensor(kNode); this.k = arr.getInt(0); addIArgument(k); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MirrorPad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MirrorPad.java index 1e84fa3f2..bed056888 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MirrorPad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MirrorPad.java @@ -43,7 +43,7 @@ public class MirrorPad extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); iArguments.add(isSymmetric ? 1L : 0L); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ParallelConcat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ParallelConcat.java index e54a9dc40..3d167ea9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ParallelConcat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ParallelConcat.java @@ -42,7 +42,7 @@ public class ParallelConcat extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); // We might want to import everything here? i.e. shape in advance? } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java index 0906451cc..078af1088 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java @@ -75,7 +75,7 @@ public class ReverseSequence extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArguments(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/TopK.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/TopK.java index 0779fd693..e9d40264f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/TopK.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/TopK.java @@ -82,7 +82,7 @@ public class TopK extends DynamicCustomOp { if (kNode != null) { Preconditions.checkState(kNode != null, "Could not find 'k' parameter node for op: %s", thisName); - INDArray arr = TFGraphMapper.getInstance().getNDArrayFromTensor(inputName, kNode, graph); + INDArray arr = TFGraphMapper.getNDArrayFromTensor(kNode); this.k = arr.getInt(0); addIArgument(ArrayUtil.fromBoolean(sorted), k); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java index 98a479542..eb8b820ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java @@ -84,7 +84,7 @@ public class Cast extends BaseDynamicTransformOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 98f92aaaa..08a546faa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -5708,15 +5708,22 @@ public class Nd4j { for (int e = 0; e < shapeInfo.length; e++) shapeInfo[e] = array.shape(e); - if (Shape.isEmpty(shapeInfo)) - return Nd4j.empty(); + val shapeOf = Shape.shapeOf(shapeInfo); + DataType _dtype = FlatBuffersMapper.getDataTypeFromByte(dtype); + if (Shape.isEmpty(shapeInfo)) { + if(Shape.rank(shapeInfo) == 0) { + return Nd4j.empty(); + } else { + return Nd4j.create(_dtype, shapeOf); + } + } char ordering = shapeInfo[shapeInfo.length - 1] == 99 ? 'c' : 'f'; - val shapeOf = Shape.shapeOf(shapeInfo); + val stridesOf = Shape.stridesOf(shapeInfo); - val _dtype = FlatBuffersMapper.getDataTypeFromByte(dtype); + val _order = FlatBuffersMapper.getOrderFromByte(order); val prod = rank > 0 ? ArrayUtil.prod(shapeOf) : 1; @@ -5809,6 +5816,18 @@ public class Nd4j { b.put(e, sb.get(e)); return Nd4j.create(b, shapeOf); + case BFLOAT16: + case UINT16: + INDArray arr = Nd4j.createUninitialized(_dtype, shapeOf); + ByteBuffer obb = bb.order(_order); + int pos = obb.position(); + byte[] bArr = new byte[obb.limit() - pos]; + + for (int e = 0; e < bArr.length; e++) { + bArr[e] = obb.get(e + pos); + } + arr.data().asNio().put(bArr); + return arr; default: throw new UnsupportedOperationException("Unknown datatype: [" + _dtype + "]"); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java index 8253b67bb..1abd7a3de 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java @@ -27,6 +27,7 @@ import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.graph.FlatArray; import org.nd4j.graph.FlatResult; import org.nd4j.graph.FlatVariable; @@ -115,15 +116,17 @@ public class NativeGraphExecutioner implements GraphExecutioner { for (int e = 0; e < fr.variablesLength(); e++) { FlatVariable var = fr.variables(e); + String varName = var.name(); // log.info("Var received: id: [{}:{}/<{}>];", var.id().first(), var.id().second(), var.name()); FlatArray ndarray = var.ndarray(); - INDArray val = Nd4j.createFromFlatArray(ndarray); results[e] = val; if (var.name() != null && sd.variableMap().containsKey(var.name())) { - sd.associateArrayWithVariable(val, sd.variableMap().get(var.name())); + if(sd.getVariable(varName).getVariableType() != VariableType.ARRAY){ + sd.associateArrayWithVariable(val, sd.variableMap().get(var.name())); + } } else { if (sd.variableMap().get(var.name()) != null) { sd.associateArrayWithVariable(val,sd.getVariable(var.name())); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index a328d788e..a1746134c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -120,7 +120,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { log.warn("Warning: Initializing ND4J with " + binLevel + " binary on a CPU with " + optLevel + " support"); log.warn("Using ND4J with " + optLevel + " will improve performance. See deeplearning4j.org/cpu for more details"); log.warn("Or set environment variable " + ND4JEnvironmentVars.ND4J_IGNORE_AVX + "=true to suppress this warning"); - log.warn("************************************************************************************************"); + log.warn("*************************************************************************************************"); } blas = new CpuBlas(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java index 15ed777d4..24a05d73e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java @@ -17,10 +17,13 @@ package org.nd4j.autodiff; import org.junit.Test; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.AbstractSession; import org.nd4j.autodiff.samediff.internal.InferenceSession; +import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; @@ -73,7 +76,7 @@ public class TestSessions extends BaseNd4jTest { m.put("y", y); Map outMap = is.output(Collections.singletonList("out"), m, null, - Collections.emptyList(), true, null); + Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); assertEquals(outExp, outMap.get("out")); @@ -111,7 +114,7 @@ public class TestSessions extends BaseNd4jTest { System.out.println("----------------------------------"); Map outMap = is.output(Collections.singletonList("d"), m, null, - Collections.emptyList(), false, null); + Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); assertEquals(dExp, outMap.get("d")); @@ -146,7 +149,7 @@ public class TestSessions extends BaseNd4jTest { // String outName = merge.getVarName(); String outName = outVar.getVarName(); Map outMap = is.output(Collections.singletonList(outName), m, null, - Collections.emptyList(), false, null); + Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); INDArray out = outMap.get(outName); @@ -182,7 +185,7 @@ public class TestSessions extends BaseNd4jTest { System.out.println("----------------------------------"); Map outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), - false, null); + null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); assertEquals(expTrue, outMap.get(n)); @@ -191,12 +194,12 @@ public class TestSessions extends BaseNd4jTest { //Check false case: bArr.assign(0); is = new InferenceSession(sd); - outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), false, null); + outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); assertEquals(expFalse, outMap.get(n)); } - @Test(timeout = 60000L) + @Test(timeout = 20000L) public void testSwitchWhile() throws Exception{ /* @@ -212,18 +215,19 @@ public class TestSessions extends BaseNd4jTest { for( int numIter : new int[]{1,3}) { File f = new ClassPathResource("tf_graphs/examples/while1/iter_" + numIter + "/frozen_model.pb").getFile(); - SameDiff sd = TFGraphMapper.getInstance().importGraph(f); + SameDiff sd = TFGraphMapper.importGraph(f); System.out.println(sd.summary()); System.out.println("----------------------------------"); //This particular test/graph doesn't use placeholders InferenceSession is = new InferenceSession(sd); + is.setMmgr(new NoOpMemoryMgr()); //So arrays aren't deallocated during execution String n = "while/Exit"; String n2 = "while/Exit_1"; Map m = is.output(Arrays.asList(n, n2), Collections.emptyMap(), null, - Collections.emptyList(), false, null); + Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(2, m.size()); INDArray exp = Nd4j.scalar((float)numIter); @@ -231,7 +235,6 @@ public class TestSessions extends BaseNd4jTest { assertEquals(exp, m.get(n)); assertEquals(exp, m.get(n2)); - Map frameParents = is.getFrameParents(); Map outputs = is.getNodeOutputs(); //Some sanity checks on the internal state: //Check 1: "while/Less" should be executed numIter+1 times... i.e., numIter times through the loop, plus once to exit diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java new file mode 100644 index 000000000..f9bd75d4a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java @@ -0,0 +1,166 @@ +package org.nd4j.autodiff.internal; + +import org.junit.Test; +import org.nd4j.autodiff.samediff.internal.DependencyList; +import org.nd4j.autodiff.samediff.internal.DependencyTracker; +import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.util.Collections; + +import static junit.framework.TestCase.assertNotNull; +import static org.junit.Assert.*; + +public class TestDependencyTracker { + + @Test + public void testSimple(){ + + DependencyTracker dt = new DependencyTracker<>(); + + dt.addDependency("y", "x"); + assertTrue(dt.hasDependency("y")); + assertFalse(dt.hasDependency("x")); + assertFalse(dt.hasDependency("z")); + + DependencyList dl = dt.getDependencies("y"); + assertEquals("y", dl.getDependencyFor()); + assertNotNull(dl.getDependencies()); + assertNull(dl.getOrDependencies()); + assertEquals(Collections.singletonList("x"), dl.getDependencies()); + + dt.removeDependency("y", "x"); + assertFalse(dt.hasDependency("y")); + assertFalse(dt.hasDependency("x")); + dl = dt.getDependencies("y"); + assertTrue(dl.getDependencies() == null || dl.getDependencies().isEmpty()); + assertTrue(dl.getOrDependencies() == null || dl.getOrDependencies().isEmpty()); + + + //Or dep + dt.addOrDependency("y", "x1", "x2"); + assertTrue(dt.hasDependency("y")); + dl = dt.getDependencies("y"); + assertTrue(dl.getDependencies() == null || dl.getDependencies().isEmpty()); + assertTrue(dl.getOrDependencies() != null && !dl.getOrDependencies().isEmpty()); + assertEquals(Collections.singletonList(new Pair<>("x1", "x2")), dl.getOrDependencies()); + + dt.removeDependency("y", "x1"); + assertFalse(dt.hasDependency("y")); + dl = dt.getDependencies("y"); + assertTrue(dl.getDependencies() == null || dl.getDependencies().isEmpty()); + assertTrue(dl.getOrDependencies() == null || dl.getOrDependencies().isEmpty()); + + dt.addOrDependency("y", "x1", "x2"); + dl = dt.getDependencies("y"); + assertTrue(dl.getDependencies() == null || dl.getDependencies().isEmpty()); + assertTrue(dl.getOrDependencies() != null && !dl.getOrDependencies().isEmpty()); + assertEquals(Collections.singletonList(new Pair<>("x1", "x2")), dl.getOrDependencies()); + dt.removeDependency("y", "x2"); + assertTrue(dt.isEmpty()); + } + + @Test + public void testSatisfiedBeforeAdd(){ + DependencyTracker dt = new DependencyTracker<>(); + + //Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency + // and check that y is added to satisfied list... + dt.markSatisfied("x", true); + dt.addDependency("y", "x"); + assertTrue(dt.hasNewAllSatisfied()); + assertEquals("y", dt.getNewAllSatisfied()); + + //Same as above - x satisfied, add x->y, then add z->y + //y should go from satisfied to not satisfied + dt.clear(); + assertTrue(dt.isEmpty()); + dt.markSatisfied("x", true); + dt.addDependency("y", "x"); + assertTrue(dt.hasNewAllSatisfied()); + dt.addDependency("y", "z"); + assertFalse(dt.hasNewAllSatisfied()); + + + //x satisfied, then or(x,y) -> z added + dt.markSatisfied("x", true); + dt.addOrDependency("z", "x", "y"); + assertTrue(dt.hasNewAllSatisfied()); + assertEquals("z", dt.getNewAllSatisfied()); + + + //x satisfied, then or(x,y) -> z added, then or(a,b)->z added (should be unsatisfied) + dt.clear(); + assertTrue(dt.isEmpty()); + dt.markSatisfied("x", true); + dt.addOrDependency("z", "x", "y"); + assertTrue(dt.hasNewAllSatisfied()); + dt.addOrDependency("z", "a", "b"); + assertFalse(dt.hasNewAllSatisfied()); + } + + @Test + public void testMarkUnsatisfied(){ + + DependencyTracker dt = new DependencyTracker<>(); + dt.addDependency("y", "x"); + dt.markSatisfied("x", true); + assertTrue(dt.hasNewAllSatisfied()); + + dt.markSatisfied("x", false); + assertFalse(dt.hasNewAllSatisfied()); + dt.markSatisfied("x", true); + assertTrue(dt.hasNewAllSatisfied()); + assertEquals("y", dt.getNewAllSatisfied()); + assertFalse(dt.hasNewAllSatisfied()); + + + //Same for OR dependencies + dt.clear(); + assertTrue(dt.isEmpty()); + dt.addOrDependency("z", "x", "y"); + dt.markSatisfied("x", true); + assertTrue(dt.hasNewAllSatisfied()); + + dt.markSatisfied("x", false); + assertFalse(dt.hasNewAllSatisfied()); + dt.markSatisfied("x", true); + assertTrue(dt.hasNewAllSatisfied()); + assertEquals("z", dt.getNewAllSatisfied()); + assertFalse(dt.hasNewAllSatisfied()); + } + + + @Test + public void testIdentityDependencyTracker(){ + IdentityDependencyTracker dt = new IdentityDependencyTracker<>(); + assertTrue(dt.isEmpty()); + + INDArray y1 = Nd4j.scalar(0); + INDArray y2 = Nd4j.scalar(0); + String x1 = "x1"; + dt.addDependency(y1, x1); + + assertFalse(dt.hasNewAllSatisfied()); + assertTrue(dt.hasDependency(y1)); + assertFalse(dt.hasDependency(y2)); + assertFalse(dt.isSatisfied(x1)); + + DependencyList dl = dt.getDependencies(y1); + assertSame(y1, dl.getDependencyFor()); //Should be same object + assertEquals(Collections.singletonList(x1), dl.getDependencies()); + assertNull(dl.getOrDependencies()); + + + //Mark as satisfied, check if it's added to list + dt.markSatisfied(x1, true); + assertTrue(dt.isSatisfied(x1)); + assertTrue(dt.hasNewAllSatisfied()); + INDArray get = dt.getNewAllSatisfied(); + assertSame(y1, get); + assertFalse(dt.hasNewAllSatisfied()); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 539901a41..7939bcdae 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -123,27 +123,24 @@ public class LayerOpValidation extends BaseOpValidation { public void testBiasAdd() { Nd4j.getRandom().setSeed(12345); - for (boolean rank1Bias : new boolean[]{false, true}) { + SameDiff sameDiff = SameDiff.create(); + INDArray input = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(new long[]{2, 4}); + INDArray b = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).divi(4); - SameDiff sameDiff = SameDiff.create(); - INDArray input = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(new long[]{2, 4}); - INDArray b = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(rank1Bias ? new long[]{4} : new long[]{1, 4}).divi(4); + SDVariable sdInput = sameDiff.var("input", input); + SDVariable sdBias = sameDiff.var("bias", b); - SDVariable sdInput = sameDiff.var("input", input); - SDVariable sdBias = sameDiff.var("bias", b); + SDVariable res = sameDiff.nn().biasAdd(sdInput, sdBias, true); + SDVariable loss = sameDiff.standardDeviation(res, true); - SDVariable res = sameDiff.nn().biasAdd(sdInput, sdBias); - SDVariable loss = sameDiff.standardDeviation(res, true); + INDArray exp = input.addRowVector(b); - INDArray exp = input.addRowVector(b); + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(res.getVarName(), exp); - TestCase tc = new TestCase(sameDiff) - .gradientCheck(true) - .expectedOutput(res.getVarName(), exp); - - String err = OpValidation.validate(tc); - assertNull(err); - } + String err = OpValidation.validate(tc); + assertNull(err); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index eb228bf1f..37c1a7086 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -586,7 +586,8 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable varMul = varMulPre.mul("d", sdVariable1); SDVariable sum = sameDiff.sum("ret", varMul, Integer.MAX_VALUE); - sameDiff.execBackwards(Collections.emptyMap()); + Map m = sameDiff.outputAll(null); + Map gm = sameDiff.calculateGradients(null, m.keySet()); SDVariable finalResult = sameDiff.grad(sum.getVarName()); @@ -597,7 +598,7 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable wGrad = sameDiff.grad(sdVariable1.getVarName()); SDVariable dGrad = sameDiff.grad(varMul.getVarName()); - INDArray scalarGradTest = finalResult.getArr(); + INDArray scalarGradTest = gm.get(sum.getVarName()); assertEquals(scalar, scalarGradTest); @@ -1265,17 +1266,17 @@ public class MiscOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5)); - SDVariable merged = sd.math().mergeAvg(var); + SDVariable merged = sd.math().mergeAvg("merged", var); SDVariable sum = sd.sum(merged); - sd.execAndEndResult(); - sd.execBackwards(Collections.emptyMap()); + Map m = sd.output(Collections.emptyMap(), "merged"); + Map gm = sd.calculateGradients(null, "in"); - INDArray out = merged.getArr(); + INDArray out = m.get("merged"); assertEquals(1, out.rank()); - INDArray inGrad = var.getGradient().getArr(); - assertEquals(1, inGrad.rank()); //Fails here, getting rank 2 + INDArray inGrad = gm.get("in"); + assertEquals(1, inGrad.rank()); } @Test @@ -1643,10 +1644,10 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable v = new StopGradient(sd, w).outputVariable(); SDVariable loss = v.std(true); - sd.execBackwards(null); + Map gm = sd.calculateGradients(null, v.getVarName(), w.getVarName()); - INDArray vArr = v.getGradient().getArr(); - INDArray wArr = w.getGradient().getArr(); + INDArray vArr = gm.get(v.getVarName()); + INDArray wArr = gm.get(w.getVarName()); System.out.println(vArr); System.out.println(wArr); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 2cbdfd3fd..101dcdeaf 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -244,10 +244,10 @@ public class ShapeOpValidation extends BaseOpValidation { //Using stdev here: mean/sum would backprop the same gradient for each input... SDVariable stdev = sd.standardDeviation("out", expand, true); - INDArray out = sd.execAndEndResult(); + Map m = sd.outputAll(null); INDArray expOut = in.getArr().std(true); - assertArrayEquals(expExpandShape, expand.getArr().shape()); + assertArrayEquals(expExpandShape, m.get(expand.getVarName()).shape()); INDArray expExpand = inArr.dup('c').reshape(expExpandShape); String msg = "expandDim=" + i + ", source=" + p.getSecond(); @@ -304,9 +304,9 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray exp = inArr.dup('c').reshape('c', expShapePostSqueeze); - sd.execAndEndResult(); + Map m = sd.outputAll(null); - INDArray squeezed = squeeze.getArr(); + INDArray squeezed = m.get(squeeze.getVarName()); // assertArrayEquals(expShapePostSqueeze, squeezed.shape()); INDArray out = sd.execAndEndResult(); @@ -546,7 +546,7 @@ public class ShapeOpValidation extends BaseOpValidation { .testName(msg); String error = OpValidation.validate(tc, true); if(error != null){ - failed.add(msg); + failed.add(msg + " - " + error); } } } @@ -712,9 +712,9 @@ public class ShapeOpValidation extends BaseOpValidation { String msg = "Unstacked shape = " + Arrays.toString(shape) + ", stacked shape = " + Arrays.toString(stackedShape) + ", axis=" + axis + ", numInputs=" + numInputs; - sd.execAndEndResult(); + Map m = sd.outputAll(null); for (SDVariable v : unstacked) { - assertArrayEquals(msg, shape, v.getArr().shape()); + assertArrayEquals(msg, shape, m.get(v.getVarName()).shape()); } TestCase tc = new TestCase(sd).testName(msg); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java index bec6e0349..308ef0fd4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java @@ -78,82 +78,6 @@ public class FailingSameDiffTests extends BaseNd4jTest { assertArrayEquals(new long[]{3,3}, list.get(0).getShape()); } - @Test(timeout = 10000L) - public void testWhileLoop() { - OpValidationSuite.ignoreFailing(); - SameDiff sameDiff = SameDiff.create(); - sameDiff.whileStatement(new DefaultSameDiffConditional(), new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable eqResult = sameDiff.neq(variableInputs[0], variableInputs[1]); - return new SDVariable[]{eqResult}; - } - }, new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable ret = variableInputs[1].add(1.0); - return new SDVariable[]{variableInputs[0], ret}; - } - }, new SDVariable[]{ - sameDiff.one("one", new long[]{1, 1}), - sameDiff.var("two", new long[]{1, 1}), - - }); - - sameDiff.exec(Collections.emptyMap()); - } - - @Test(timeout = 10000L) - public void testWhileBackwards() { - OpValidationSuite.ignoreFailing(); - SameDiff sameDiff = SameDiff.create(); - sameDiff.whileStatement(new DefaultSameDiffConditional(), new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable eqResult = sameDiff.neq(variableInputs[0], variableInputs[1]); - return new SDVariable[]{eqResult}; - } - }, new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable ret = variableInputs[1].add(1.0); - return new SDVariable[]{variableInputs[0], ret}; - } - }, new SDVariable[]{ - sameDiff.one("one", new long[]{1, 1}), - sameDiff.var("two", new long[]{1, 1}), - - }); - - sameDiff.execBackwards(Collections.emptyMap()); - SameDiff exec = sameDiff.getFunction("grad"); - } - - @Test(timeout = 10000L) - public void testWhileLoop2() { - OpValidationSuite.ignoreFailing(); - SameDiff sameDiff = SameDiff.create(); - sameDiff.whileStatement(new DefaultSameDiffConditional(), new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable eqResult = sameDiff.neq(variableInputs[0], variableInputs[1]); - return new SDVariable[]{eqResult}; - } - }, new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable ret = variableInputs[1].add(1.0); - return new SDVariable[]{variableInputs[0], ret}; - } - }, new SDVariable[]{ - sameDiff.one("one", new long[]{1, 1}), - sameDiff.var("two", new long[]{1, 1}), - - }); - - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - } - @Test public void testExecutionDifferentShapesTransform(){ OpValidationSuite.ignoreFailing(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java index c291a5556..a3ee34570 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java @@ -317,7 +317,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { } for(SDVariable v : sd.variables()){ - if(v.isPlaceHolder()) + if(v.isPlaceHolder() || v.getVariableType() == VariableType.ARRAY) continue; SDVariable v2 = sd2.getVariable(v.getVarName()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 7d17b3604..f19a4ec8b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -29,6 +29,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.io.IOException; import java.lang.reflect.Field; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -295,8 +296,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable y = sameDiff.var("y", arr); SDVariable result = sameDiff.mmul(x, y); SDVariable otherResult = result.add(result); - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - assertArrayEquals(new long[]{2, 2}, result.getShape()); + Map m = sameDiff.outputAll(null); + assertArrayEquals(new long[]{2, 2}, m.get(result.getVarName()).shape()); } @@ -544,143 +545,6 @@ public class SameDiffTests extends BaseNd4jTest { } - - @Test - public void testIfStatementTrueBodyBackwards() { - OpValidationSuite - .ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations - SameDiff sameDiff = SameDiff.create(); - SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sum = sameDiff.sum(variableInputs[0], Integer.MAX_VALUE); - SDVariable result = sameDiff.gt(sum, 1.0); - return new SDVariable[]{result}; - } - }; - - SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable add = variableInputs[0].add(1.0); - return new SDVariable[]{add}; - } - }; - - SameDiffFunctionDefinition falseBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sub = variableInputs[0].sub(1.0); - return new SDVariable[]{sub}; - } - }; - - //true body trigger - SDVariable[] firstInputs = new SDVariable[]{ - sameDiff.var("one", new long[]{1, 1}) - - }; - - sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs); - sameDiff.execBackwards(Collections.emptyMap()); - SameDiff grad = sameDiff.getFunction("grad"); - /* If ifBlock = (If) grad.getFunction(new long[]{1},new long[]{2}); - SameDiff assertComparision = SameDiff.create(); - SDVariable initialInput = assertComparision.zero("zero",new long[]{1,1}); - initialInput.addi(1.0); - assumeNotNull(ifBlock.getTrueBodyExecuted()); - assertTrue(ifBlock.getTrueBodyExecuted()); - assertEquals(Nd4j.scalar(1.00),initialInput.getArr()); - assertEquals(Nd4j.scalar(1.0),ifBlock.getLoopBodyExecution().getVariableForVertexId(2).getArr()); -*/ - } - - - @Test - public void testIfStatementTrueBody() { - OpValidationSuite - .ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations - SameDiff sameDiff = SameDiff.create(); - - SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sum = sameDiff.sum(variableInputs[0], Integer.MAX_VALUE); - SDVariable result = sameDiff.gt(sum, 1.0); - return new SDVariable[]{result}; - } - }; - - SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable add = variableInputs[0].add(1.0); - return new SDVariable[]{add}; - } - }; - - SameDiffFunctionDefinition falseBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sub = variableInputs[0].sub(1.0); - return new SDVariable[]{sub}; - } - }; - - //true body trigger - SDVariable[] firstInputs = new SDVariable[]{ - sameDiff.var("one", new long[]{1, 1}) - - }; - - sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs); - sameDiff.exec(Collections.emptyMap()); - } - - - @Test - public void testIfStatementFalseBody() { - OpValidationSuite - .ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations - SameDiff sameDiff = SameDiff.create(); - - SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sum = sameDiff.sum(variableInputs[0], Integer.MAX_VALUE); - SDVariable result = sameDiff.gt(sum, 1.0); - return new SDVariable[]{result}; - } - }; - - SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable add = variableInputs[0].add(1.0); - return new SDVariable[]{add}; - } - }; - - SameDiffFunctionDefinition falseBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sub = variableInputs[0].sub(1.0); - return new SDVariable[]{sub}; - } - }; - - //false body trigger - SDVariable[] secondInputs = new SDVariable[]{ - sameDiff.setupFunction(sameDiff.var("two", new long[]{1, 1})) - - }; - - sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, secondInputs); - - sameDiff.exec(Collections.emptyMap()); - } - - @Test public void testAutoBroadcastAddMatrixVector() { SameDiff sameDiff = SameDiff.create(); @@ -813,10 +677,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", new long[]{10, 9, 8}); SDVariable mean1 = sd.mean(in, 2); //[10,9] out SDVariable mean2 = sd.mean(mean1, 1); //[10] out - sd.execAndEndResult(); + Map m = sd.output((Map)null, mean1.getVarName(), mean2.getVarName()); - INDArray m1 = mean1.getArr(); - INDArray m2 = mean2.getArr(); + INDArray m1 = m.get(mean1.getVarName()); + INDArray m2 = m.get(mean2.getVarName()); assertArrayEquals(new long[]{10, 9}, m1.shape()); assertArrayEquals(new long[]{10}, m2.shape()); @@ -833,16 +697,16 @@ public class SameDiffTests extends BaseNd4jTest { assertArrayEquals(new long[]{9, 8}, meanA.getShape()); SDVariable meanB = sd2.mean(meanA, 0); //[8] out - sd2.exec(null, sd2.outputs()); - assertArrayEquals(new long[]{8}, meanB.getShape()); + Map m = sd2.outputAll(null); + assertArrayEquals(new long[]{8}, m.get(meanB.getVarName()).shape()); - assertArrayEquals(meanA.getShape(), meanA.getArr().shape()); - assertArrayEquals(meanB.getShape(), meanB.getArr().shape()); + assertArrayEquals(meanA.getShape(), m.get(meanA.getVarName()).shape()); + assertArrayEquals(meanB.getShape(), m.get(meanB.getVarName()).shape()); - sd2.exec(Collections.emptyMap(), sd2.outputs()); + m = sd2.outputAll(null); - INDArray mA = meanA.getArr(); - INDArray mB = meanB.getArr(); + INDArray mA = m.get(meanA.getVarName()); + INDArray mB = m.get(meanB.getVarName()); assertArrayEquals(new long[]{9, 8}, mA.shape()); assertArrayEquals(new long[]{8}, mB.shape()); @@ -858,10 +722,10 @@ public class SameDiffTests extends BaseNd4jTest { val f = m.add(2.0); val s = in2.add(5.0); - val arr = sd.execSingle(null, s.getVarName()); - log.info("Result M: {}", m.getArr()); - log.info("Result F: {}", f.getArr()); - log.info("Result S: {}", s.getArr()); + Map map = sd.outputAll(null); + log.info("Result M: {}", map.get(m.getVarName())); + log.info("Result F: {}", map.get(f.getVarName())); + log.info("Result S: {}", map.get(s.getVarName())); } @Test @@ -1097,11 +961,11 @@ public class SameDiffTests extends BaseNd4jTest { INDArray expZ = expMmul.addRowVector(iBias); INDArray expOut = Transforms.sigmoid(expZ, true); - sd.exec(Collections.emptyMap(), sd.outputs()); + Map m = sd.outputAll(Collections.emptyMap()); - assertEquals(expMmul, mmul.getArr()); - assertEquals(expZ, z.getArr()); - assertEquals(expOut, out.getArr()); + assertEquals(expMmul, m.get(mmul.getVarName())); + assertEquals(expZ, m.get(z.getVarName())); + assertEquals(expOut, m.get(out.getVarName())); } @Test @@ -1178,8 +1042,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sqDiff = diff.mul("sqDiff", diff); SDVariable totSum = sd.sum("totSum", sqDiff, Integer.MAX_VALUE); //Loss function... - sd.exec(Collections.emptyMap(), sd.outputs()); - INDArray outAct = sd.getVariable("out").getArr(); + Map m = sd.output(Collections.emptyMap(), "out"); + INDArray outAct = m.get("out"); assertEquals(a.toString(), outExp, outAct); // L = sum_i (label - out)^2 @@ -1187,10 +1051,11 @@ public class SameDiffTests extends BaseNd4jTest { INDArray dLdOutExp = outExp.sub(labelArr).mul(2); INDArray dLdInExp = a.getActivationFunction().backprop(inArr.dup(), dLdOutExp.dup()).getFirst(); - sd.execBackwards(Collections.emptyMap()); - SameDiff gradFn = sd.getFunction("grad"); - INDArray dLdOutAct = gradFn.getVariable("out-grad").getArr(); - INDArray dLdInAct = gradFn.getVariable("in-grad").getArr(); + Map grads = sd.calculateGradients(null, "out", "in"); +// sd.execBackwards(Collections.emptyMap()); +// SameDiff gradFn = sd.getFunction("grad"); + INDArray dLdOutAct = grads.get("out"); + INDArray dLdInAct = grads.get("in"); assertEquals(a.toString(), dLdOutExp, dLdOutAct); assertEquals(a.toString(), dLdInExp, dLdInAct); @@ -1256,10 +1121,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = sd.cnn().localResponseNormalization(sdInput, lrn); SDVariable sdOut = sd.math().tanh("out", out); - sd.exec(Collections.emptyMap(), sd.outputs()); + Map map = sd.output(Collections.emptyMap(), "out", out.getVarName()); for (int i = 0; i < 4; i++) { - assertEquals(1, out.getArr().get(all(), NDArrayIndex.point(i), all(), all()).getInt(0)); + assertEquals(1, map.get(out.getVarName()).get(all(), NDArrayIndex.point(i), all(), all()).getInt(0)); } } @@ -1280,10 +1145,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sum = mean.add(variance); SDVariable out = sd.math().tanh("out", sum); - INDArray outArr = sd.execAndEndResult(); + Map m = sd.outputAll(null); - INDArray meanArray = mean.getArr(); - INDArray varArray = variance.getArr(); + INDArray meanArray = m.get(mean.getVarName()); + INDArray varArray = m.get(variance.getVarName()); assertEquals(meanArray.getDouble(0), 2.5, 1e-5); assertEquals(varArray.getDouble(0), 1.25, 1e-5); @@ -1309,15 +1174,14 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sum = normMean.add(normVariance); SDVariable out = sd.math().tanh("out", sum); - INDArray outArr = sd.execAndEndResult(); + Map m = sd.outputAll(null); - INDArray meanArray = normMean.getArr(); - INDArray varArray = normVariance.getArr(); + INDArray meanArray = m.get(normMean.getVarName()); + INDArray varArray = m.get(normVariance.getVarName()); assertEquals(meanArray.getDouble(0, 0), 1, 1e-5); assertEquals(meanArray.getDouble(0, 1), 2, 1e-5); assertArrayEquals(meanArray.shape(), varArray.shape()); - } @@ -2469,72 +2333,25 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable loss = out.std("out", true); INDArray outArr = sd.execAndEndResult().dup(); - sd.execBackwards(Collections.emptyMap()); +// sd.execBackwards(Collections.emptyMap()); + Map grads = sd.calculateGradients(null, in.getVarName(), w.getVarName(), out.getVarName()); Map origGrad = new HashMap<>(); - origGrad.put("in", in.gradient().getArr().dup()); - origGrad.put("w", w.gradient().getArr().dup()); - origGrad.put("out", out.gradient().getArr().dup()); + origGrad.put("in", grads.get(in.getVarName()).dup()); + origGrad.put("w", grads.get(w.getVarName()).dup()); + origGrad.put("out", grads.get(out.getVarName()).dup()); in.getArr().assign(Nd4j.rand(in.getArr().shape())); INDArray outArr2 = sd.execAndEndResult(); - sd.execBackwards(Collections.emptyMap()); +// sd.execBackwards(Collections.emptyMap()); + grads = sd.calculateGradients(null, in.getVarName(), w.getVarName(), out.getVarName()); assertNotEquals(outArr, outArr2); //Ensure gradients are also changed: - assertNotEquals(origGrad.get("in"), in.gradient().getArr()); - assertNotEquals(origGrad.get("w"), w.gradient().getArr()); - assertNotEquals(origGrad.get("out"), out.gradient().getArr()); - } - - @Test - public void testUpdatingInplaceFwd() { - SameDiff sd = SameDiff.create(); - SDVariable in = sd.var("in", Nd4j.linspace(1, 12, 12).reshape(3, 4)); - SDVariable w = sd.var("w", Nd4j.linspace(1, 20, 20).reshape(4, 5)); - SDVariable out = sd.mmul(in, w); - SDVariable loss = out.std("out", true); - - INDArray outArr = sd.execAndEndResult().dup(); - sd.execBackwards(Collections.emptyMap()); - - Map origGrad = new HashMap<>(); - origGrad.put("in", in.gradient().getArr().dup()); - origGrad.put("w", w.gradient().getArr().dup()); - origGrad.put("out", out.gradient().getArr().dup()); - - in.getArr().muli(5); - - //check gradient function copy of array - SameDiff sdGrad = sd.getFunction("grad"); - INDArray gradArrIn = sdGrad.getVariable("in").getArr(); - assertEquals(in.getArr(), gradArrIn); - } - - @Test - public void testUpdatingAssociateFwd() { - SameDiff sd = SameDiff.create(); - SDVariable in = sd.var("in", Nd4j.linspace(1, 12, 12).reshape(3, 4)); - SDVariable w = sd.var("w", Nd4j.linspace(1, 20, 20).reshape(4, 5)); - SDVariable out = sd.mmul(in, w); - SDVariable loss = out.std("out", true); - - INDArray outArr = sd.execAndEndResult().dup(); - sd.execBackwards(Collections.emptyMap()); - - Map origGrad = new HashMap<>(); - origGrad.put("in", in.gradient().getArr().dup()); - origGrad.put("w", w.gradient().getArr().dup()); - origGrad.put("out", out.gradient().getArr().dup()); - - INDArray newIn = in.getArr().dup().muli(5); - in.setArray(newIn); - - //check gradient function copy of array - SameDiff sdGrad = sd.getFunction("grad"); - INDArray gradArrIn = sdGrad.getVariable("in").getArr(); - assertEquals(newIn, gradArrIn); + assertNotEquals(origGrad.get("in"), grads.get(in.getVarName())); + assertNotEquals(origGrad.get("w"), grads.get(w.getVarName())); + assertNotEquals(origGrad.get("out"), grads.get(out.getVarName())); } @Test @@ -2545,26 +2362,24 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable loss = out.std("out", true); INDArray outArr = sd.execAndEndResult().dup(); - sd.execBackwards(Collections.emptyMap()); - - SameDiff sdGrad = sd.getFunction("grad"); + Map grads = sd.calculateGradients(null, in.getVarName(), out.getVarName()); Map origGrad = new HashMap<>(); - origGrad.put("in", in.gradient().getArr().dup()); - origGrad.put("out", out.gradient().getArr().dup()); + origGrad.put("in", grads.get(in.getVarName()).dup()); + origGrad.put("out", grads.get(out.getVarName()).dup()); double stdBefore = in.getArr().stdNumber().doubleValue(); in.getArr().assign(Nd4j.rand(in.getArr().shape())); double stdAfter = in.getArr().stdNumber().doubleValue(); System.out.println("Before vs. after: " + stdBefore + ", " + stdAfter); INDArray outArr2 = sd.execAndEndResult(); - sd.execBackwards(Collections.emptyMap()); + grads = sd.calculateGradients(null, in.getVarName(), out.getVarName()); assertNotEquals(outArr, outArr2); //Ensure gradients are also changed: - assertNotEquals(origGrad.get("in"), in.gradient().getArr()); - assertNotEquals(origGrad.get("out"), out.gradient().getArr()); + assertNotEquals(origGrad.get("in"), grads.get(in.getVarName())); + assertNotEquals(origGrad.get("out"), grads.get(out.getVarName())); } @Test @@ -2589,10 +2404,10 @@ public class SameDiffTests extends BaseNd4jTest { Map phMap = new HashMap<>(); phMap.put(fn.getGradPlaceholderName(), grad); - log.info("--------------- sd.execAndEndResult() ---------------"); - sd.execAndEndResult(); + log.info("--------------- out.eval() ---------------"); + out.eval(); log.info("--------------- sd.execBackwards() #1 ---------------"); - sd.execBackwards(phMap); + sd.calculateGradients(phMap, "in", "W", "b"); log.info("--------------- sd.execBackwards() #2 ---------------"); System.out.println(sd.getFunction("grad").summary()); @@ -2658,9 +2473,8 @@ public class SameDiffTests extends BaseNd4jTest { Map placeholders = new HashMap<>(); placeholders.put("x", x); placeholders.put("y", y); - sd.createGradFunction(); //Otherwise: xSd.gradient() etc won't be defined - sd.execBackwards(placeholders, Arrays.asList(xSd.gradient().getVarName(), ySd.gradient().getVarName())); - INDArray xGradientEnforced = add.getGradient().getArr(true); + Map grads = sd.calculateGradients(placeholders, xSd.getVarName(), ySd.getVarName()); + INDArray xGradientEnforced = grads.get("x"); assertNotNull(xGradientEnforced); } @@ -2778,7 +2592,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray out2 = tanh.eval(); - assertEquals(out, out2); + assertNotEquals(out, out2); assertEquals(VariableType.VARIABLE, w.getVariableType()); assertEquals(VariableType.VARIABLE, b.getVariableType()); assertEquals(VariableType.ARRAY, add.getVariableType()); @@ -3133,6 +2947,7 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testPlaceholderShapeValidation() { SameDiff sd = SameDiff.create(); + SDVariable scalar = sd.scalar("scalar", 0.0f); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4); SDVariable ph2 = sd.placeHolder("ph2", DataType.FLOAT, -1, 4); SDVariable ph3 = sd.placeHolder("ph3", DataType.FLOAT, 3, -1); @@ -3177,7 +2992,7 @@ public class SameDiffTests extends BaseNd4jTest { //Also try training: SDVariable sum = sd.math.mergeAdd(ph1, ph2, ph3, ph4); - SDVariable mean = sum.mean(); + SDVariable mean = sum.add(scalar).mean(); MultiDataSet mds = new MultiDataSet(new INDArray[]{wrongShape, wrongShape, wrongShape, wrongShape}, null); sd.setTrainingConfig(TrainingConfig.builder() @@ -3411,29 +3226,28 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testIf() throws IOException { - SameDiff SD = SameDiff.create(); - SDVariable a = SD.placeHolder("a", DataType.DOUBLE); - SDVariable b = SD.var("b", Nd4j.createFromArray(5.0)); - SDVariable c = SD.var("c", Nd4j.createFromArray(9.0)); + SameDiff sd = SameDiff.create(); + SDVariable a = sd.placeHolder("a", DataType.DOUBLE); + SDVariable b = sd.var("b", Nd4j.createFromArray(5.0)); + SDVariable c = sd.var("c", Nd4j.createFromArray(9.0)); - SDVariable output = SD.ifCond("out", null, (sd) -> a.lt(b), (sd) -> c, (sd) -> c.add(5)); + SDVariable output = sd.ifCond("out", null, s -> a.lt(b), s -> c, s -> c.add(5)); Map firstBranch = Maps.newHashMap(); firstBranch.put("a", Nd4j.createFromArray(3.0)); - assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out")); + assertEquals(Nd4j.createFromArray(9.0), sd.output(firstBranch, "out").get("out")); Map secondBranch = Maps.newHashMap(); secondBranch.put("a", Nd4j.createFromArray(7.0)); - assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out")); - - //TODO complains that it can't deserialize a meta type, but there are no meta type ops here - // looks like a difference between Op.Type and OpType. Switch is saved as a OpType.LOGIC - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - - assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out")); - assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out")); + System.out.println(sd.summary()); + INDArray outArr = sd.output(secondBranch, "out").get("out"); + assertEquals(Nd4j.createFromArray(14.0), outArr); + ByteBuffer bb = sd.asFlatBuffers(false); + sd = SameDiff.fromFlatBuffers(bb); + assertEquals(Nd4j.createFromArray(9.0), sd.output(firstBranch, "out").get("out")); + assertEquals(Nd4j.createFromArray(14.0), sd.output(secondBranch, "out").get("out")); } @Test @@ -3456,7 +3270,7 @@ public class SameDiffTests extends BaseNd4jTest { SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - assertEquals(Nd4j.createFromArray(10.0), SD.exec(null, "out").get("out")); + assertEquals(Nd4j.createFromArray(10.0), SD.output(Collections.emptyMap(), "out").get("out")); } @Test @@ -3477,7 +3291,7 @@ public class SameDiffTests extends BaseNd4jTest { SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - assertEquals(15, SD.exec(null, outName).get(outName).getInt(0)); + assertEquals(15, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test @@ -3503,7 +3317,7 @@ public class SameDiffTests extends BaseNd4jTest { SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - assertEquals(35, SD.exec(null, outName).get(outName).getInt(0)); + assertEquals(35, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @@ -3529,7 +3343,7 @@ public class SameDiffTests extends BaseNd4jTest { SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - assertEquals(115, SD.exec(null, outName).get(outName).getInt(0)); + assertEquals(115, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java index 2411af627..417652dcc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java @@ -181,7 +181,8 @@ public class UIListenerTest { SameDiff sd2 = SameDiff.create(); SDVariable in1 = sd2.placeHolder("in1", DataType.FLOAT, -1, 4); SDVariable in2 = sd2.placeHolder("in2", DataType.FLOAT, -1, 4); - SDVariable mul = in1.mul(in2); + SDVariable w = sd2.var("w", DataType.FLOAT, 1, 4); + SDVariable mul = in1.mul(in2).mul(w); SDVariable loss = mul.std(true); sd2.setTrainingConfig(TrainingConfig.builder() .dataSetFeatureMapping("in") diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java index 1f6301b2f..ddf4775a8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java @@ -75,7 +75,8 @@ public class ExecutionTests extends BaseNd4jTest { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraphTxt(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream(), null, null); + System.out.println(tg.summary()); Map result_0 = tg.exec(Collections.emptyMap(), tg.outputs()); val exp_0 = Nd4j.create(DataType.FLOAT, 3).assign(3.0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java index 074cef780..c70dfa436 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java @@ -23,7 +23,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.transform.*; -import org.nd4j.base.Preconditions; import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.tensorflow.TFImportOverride; @@ -35,7 +34,6 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.resources.Downloader; import org.nd4j.util.ArchiveUtils; @@ -109,7 +107,7 @@ public class BERTGraphTest extends BaseNd4jTest { //Skip the "IteratorV2" op - we don't want or need this TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) -> { return "IteratorV2".equals(nodeDef.getName()); }; - SameDiff sd = TFGraphMapper.getInstance().importGraph(f, m, filter); + SameDiff sd = TFGraphMapper.importGraph(f, m, filter); /* Modify the network to remove hard-coded dropout operations for inference. @@ -317,7 +315,7 @@ public class BERTGraphTest extends BaseNd4jTest { //Skip the "IteratorV2" op - we don't want or need this TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) -> { return "IteratorV2".equals(nodeDef.getName()); }; - SameDiff sd = TFGraphMapper.getInstance().importGraph(f, m, filter); + SameDiff sd = TFGraphMapper.importGraph(f, m, filter); /* Set floatConstants = new HashSet<>(Arrays.asList( @@ -431,7 +429,7 @@ public class BERTGraphTest extends BaseNd4jTest { return "IteratorV2".equals(nodeDef.getName()); }; - SameDiff sd = TFGraphMapper.getInstance().importGraph(f, m, filter); + SameDiff sd = TFGraphMapper.importGraph(f, m, filter); LogFileWriter w = new LogFileWriter(new File("C:/Temp/BERT_UI.bin")); long bytesWritten = w.writeGraphStructure(sd); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index ecc81c981..c57f8c5d9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -30,10 +30,13 @@ import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.Listener; -import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr; +import org.nd4j.autodiff.samediff.internal.memory.CloseValidationMemoryMgr; import org.nd4j.autodiff.validation.OpValidation; +import org.nd4j.autodiff.validation.TestCase; import org.nd4j.base.Preconditions; import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; @@ -62,6 +65,7 @@ import org.springframework.core.io.support.ResourcePatternResolver; import java.io.*; import java.net.URI; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.*; import java.util.regex.Pattern; @@ -84,7 +88,7 @@ public class TFGraphTestAllHelper { @Override public SameDiff apply(File file, String name) { try(InputStream is = new BufferedInputStream(new FileInputStream(file))){ - SameDiff sd = TFGraphMapper.getInstance().importGraph(is); + SameDiff sd = TFGraphMapper.importGraph(is); return sd; } catch (IOException e){ throw new RuntimeException(e); @@ -138,7 +142,19 @@ public class TFGraphTestAllHelper { " must be null or both must be provided"); Nd4j.EPS_THRESHOLD = 1e-3; - SameDiff graph = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null); + Set outputsToCheck = new HashSet<>(); + for(String s : predictions.keySet()) { + // we need to convert name from python name format with . on indices, to :. i.e.: output.1 -> output:1 + if (s.matches(".*\\.\\d+")) { + int idx = s.lastIndexOf('.'); + s = s.substring(0, idx) + ":" + s.substring(idx+1); + } + outputsToCheck.add(s); + } + + Pair> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck); + SameDiff graph = p.getFirst(); + Map sameDiffPredictions = p.getSecond(); //Collect coverage info about ops OpValidation.collectTensorflowImportCoverage(graph); @@ -156,7 +172,7 @@ public class TFGraphTestAllHelper { nd4jNode = outputNode.replaceAll("\\.", ":"); try { - nd4jPred = graph.getVariable(nd4jNode).getArr(); + nd4jPred = sameDiffPredictions.get(nd4jNode); } catch (NullPointerException e) { throw new NullPointerException("Can't find SameDiff variable with name [" + nd4jNode + "]"); } @@ -270,6 +286,12 @@ public class TFGraphTestAllHelper { log.info("\n========================================================\n"); } + //Serialize and deserialize, check equality: + ByteBuffer serialized = graph.asFlatBuffers(true); + Preconditions.checkNotNull(serialized, "Serialization failed? Null output"); + OpValidation.checkDeserializedEquality(graph, serialized, new TestCase(graph).testName(modelName).placeholderValues(inputs)); + + Nd4j.EPS_THRESHOLD = 1e-5; } @@ -285,7 +307,9 @@ public class TFGraphTestAllHelper { " must be null or both must be provided"); Nd4j.EPS_THRESHOLD = 1e-3; OpExecOrderListener listener = new OpExecOrderListener(); //Used to collect exec order - SameDiff graph = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener)); + Pair> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null); + SameDiff graph = p.getFirst(); + Map sdPredictions = p.getSecond(); //Collect coverage info about ops OpValidation.collectTensorflowImportCoverage(graph); @@ -313,7 +337,7 @@ public class TFGraphTestAllHelper { log.info("\n\tFORCING no check on " + varName); } else { assertArrayEquals("Shape not equal on node " + varName, tfValue.shape(), graph.getVariable(varName).getShape()); - INDArray sdVal = graph.getVariable(varName).getArr(); + INDArray sdVal = sdPredictions.get(varName); if(maxRelErrorOverride != null){ INDArray diff = Transforms.abs(tfValue.sub(sdVal), false); INDArray absErrorMask = diff.gte(minAbsErrorOverride); //value 1 if x[i] > minAbsError; value 0 otherwise. Used to get rid of 1e-30 vs. 1e-29 type failures @@ -362,30 +386,33 @@ public class TFGraphTestAllHelper { Nd4j.EPS_THRESHOLD = 1e-5; } - public static SameDiff getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map inputs, - ExecuteWith executeWith, BiFunction graphLoaderFunction, List listeners) throws IOException { + public static Pair> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map inputs, + ExecuteWith executeWith, BiFunction graphLoaderFunction, List listeners, + Set requiredOutputs) throws IOException { log.info("\n\tRUNNING TEST " + modelName + "..."); SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName); if(listeners != null){ graph.setListeners(listeners); } -// = TFGraphMapper.getInstance().importGraph(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getInputStream()); - //System.out.println(graph.summary()); + + if(requiredOutputs == null){ + requiredOutputs = graph.variableMap().keySet(); + } + + Map outMap = null; if (executeWith.equals(ExecuteWith.SAMEDIFF)) { - List outputs = graph.outputs(); - if(outputs.isEmpty()){ - //Edge case: no ops - List vars = graph.variables(); - outputs = new ArrayList<>(); - for(SDVariable v : vars) { - outputs.add(v.getVarName()); - } - } - if (!inputs.isEmpty()) { - graph.exec(inputs, outputs); //This is expected to be just one result - } else { - graph.exec(Collections.emptyMap(), outputs); //there are graphs with no placeholders like g_00 - } + //Set memory manager - check that all arrays (other than the ones we requested as output) + CloseValidationMemoryMgr mmgr = new CloseValidationMemoryMgr(graph, new ArrayCloseMemoryMgr()); + long tid = Thread.currentThread().getId(); + if(!graph.getSessions().containsKey(tid)) + graph.getSessions().put(tid, new InferenceSession(graph)); + //Execute + graph.getSessions().get(tid).setMmgr(mmgr); + outMap = graph.output(inputs, new ArrayList<>(requiredOutputs)); + + //Check that all arrays were released + mmgr.assertAllReleasedExcept(outMap.values()); + graph.getSessions().clear(); } else if (executeWith.equals(ExecuteWith.LIBND4J)) { for (String input : inputs.keySet()) { graph.associateArrayWithVariable(inputs.get(input), graph.variableMap().get(input)); @@ -396,7 +423,6 @@ public class TFGraphTestAllHelper { val executioner = new NativeGraphExecutioner(); val results = executioner.executeGraph(graph, configuration); - //graph.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/non2d_1.fb")); } else if (executeWith.equals(ExecuteWith.JUST_PRINT)) { for (String input : inputs.keySet()) { graph.associateArrayWithVariable(inputs.get(input), graph.variableMap().get(input)); @@ -405,7 +431,8 @@ public class TFGraphTestAllHelper { val string = graph.asFlatPrint(); log.info("Graph structure: \n{}", string); } - return graph; + + return new Pair<>(graph, outMap); } private static String[] modelDirNames(String base_dir, ExecuteWith executeWith, String modelFileName) throws IOException { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java index 1da31d863..3a39dac37 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java @@ -53,7 +53,7 @@ public class TFGraphTestList { public static String[] modelNames = new String[]{ // "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME" - "cnn2d_layers/channels_last_b1_k2_s1_d1_SAME_elu" + "accumulate_n/rank0" }; @After diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java index 8429637fd..05edef2b8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java @@ -255,7 +255,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we OpValidationSuite.ignoreFailing(); } -// if(!modelName.startsWith("ssd")){ +// if(!modelName.startsWith("mobilenet_v2_1.0_224")){ // OpValidationSuite.ignoreFailing(); // } currentTestDir = testDir.newFolder(); @@ -282,9 +282,12 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we } //Libnd4j exec: + /* + //AB 2019/10/19 - Libnd4j execution disabled pending execution rewrite currentTestDir = testDir.newFolder(); log.info("----- Libnd4j Exec: {} -----", modelName); TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.LIBND4J, LOADER, maxRE, minAbs); + */ } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index a501f9ff4..745f5f8fd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -39,7 +39,6 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.api.ops.impl.controlflow.If; import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -100,7 +99,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testSingleExample_1() { - val g =TFGraphMapper.getInstance().importGraph(new File("C:\\Users\\raver\\Downloads\\mnist.pb")); + val g = TFGraphMapper.importGraph(new File("C:\\Users\\raver\\Downloads\\mnist.pb")); val array = Nd4j.ones(1, 28, 28); g.associateArrayWithVariable(array, "flatten_1_input"); @@ -113,12 +112,12 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testAssertImport_1() { - val graph = TFGraphMapper.getInstance().importGraph(new File("C:\\Users\\raver\\Downloads\\test.pb")); + val graph = TFGraphMapper.importGraph(new File("C:\\Users\\raver\\Downloads\\test.pb")); } @Test public void testArgMaxImport_2() throws Exception { - val graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("/tf_graphs/examples/reductions/argmax3,4,5_-1/frozen_graph.pbtxt").getInputStream()); + val graph = TFGraphMapper.importGraph(new ClassPathResource("/tf_graphs/examples/reductions/argmax3,4,5_-1/frozen_graph.pbtxt").getInputStream()); graph.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/argmax_macos.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build(), true); @@ -127,7 +126,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testArgMaxImport_1() throws Exception { - val graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("/tf_graphs/argmax.pb.txt").getInputStream()); + val graph = TFGraphMapper.importGraph(new ClassPathResource("/tf_graphs/argmax.pb.txt").getInputStream()); log.info(graph.asFlatPrint()); val result = graph.execAndEndResult(); @@ -137,68 +136,6 @@ public class TensorFlowImportTest extends BaseNd4jTest { assertEquals(exp, result); } - - @Test - public void testIfStatementNodes() throws Exception { - // /home/agibsonccc/code/dl4j-test-resources/src/main/resources/tf_graphs/examples/simple_cond/frozen_graph.pbtxt - val resourceInputStream = new ClassPathResource("/tf_graphs/examples/simple_cond/frozen_model.pb").getInputStream(); - val mapper = TFGraphMapper.getInstance(); - val readGraph = TFGraphMapper.getInstance().parseGraphFrom(resourceInputStream); - val nodes = mapper.nodesByName(readGraph); - /** - * Work backwards starting fom the condition id (usually a name containing condid/pred_id: - - */ - - val firstInput = nodes.get("cond5/Merge"); - val ifNodes = mapper.nodesForIf(firstInput,readGraph); - assertEquals(5,ifNodes.getFalseNodes().size()); - assertEquals(5,ifNodes.getTrueNodes().size()); - assertEquals(10,ifNodes.getCondNodes().size()); - - - val secondInput = nodes.get("cond6/Merge"); - val ifNodesTwo = mapper.nodesForIf(secondInput,readGraph); - assertEquals(5,ifNodesTwo.getFalseNodes().size()); - assertEquals(5,ifNodesTwo.getTrueNodes().size()); - assertEquals(6,ifNodesTwo.getCondNodes().size()); - - - val parentContext = SameDiff.create(); - val ifStatement = new If(); - ifStatement.initFromTensorFlow(firstInput,parentContext,Collections.emptyMap(),readGraph); - assertNotNull(ifStatement.getLoopBodyExecution()); - assertNotNull(ifStatement.getFalseBodyExecution()); - assertNotNull(ifStatement.getPredicateExecution()); - - } - - @Test - @Ignore - public void testIfIgnoreWhileMerge() throws Exception { - val resourceInputStream = new ClassPathResource("/tf_graphs/examples/simple_while/frozen_model.pb").getInputStream(); - val mapper = TFGraphMapper.getInstance(); - val readGraph = TFGraphMapper.getInstance().parseGraphFrom(resourceInputStream); - val nodes = mapper.nodesByName(readGraph); - val firstInput = nodes.get("output/Merge"); - assertNotNull(firstInput); - assertFalse(mapper.isOpIgnoreException(firstInput)); - - val resourceInputStreamIf = new ClassPathResource("/tf_graphs/examples/simple_cond/frozen_model.pb").getInputStream(); - val readGraphIf = TFGraphMapper.getInstance().parseGraphFrom(resourceInputStreamIf); - val nodesif = mapper.nodesByName(readGraphIf); - /** - * Work backwards starting fom the condition id (usually a name containing condid/pred_id: - - */ - - val secondInput = nodesif.get("cond5/Merge"); - assertNotNull(secondInput); - assertTrue(mapper.isOpIgnoreException(secondInput)); - - } - - @Test public void testHashEquality1() { long hash = HashUtil.getLongHash("Conv2D"); @@ -222,7 +159,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void importGraph1() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream()); assertNotNull(graph); @@ -245,7 +182,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void importGraph2() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()); assertNotNull(graph); } @@ -254,7 +191,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void importGraph3() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_log_reg.pb.txt").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_log_reg.pb.txt").getInputStream()); assertNotNull(graph); } @@ -262,7 +199,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void testImportIris() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/train_iris.pb").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/train_iris.pb").getInputStream()); assertNotNull(graph); } @@ -271,7 +208,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void importGraph4() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_multiply.pb.txt").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_multiply.pb.txt").getInputStream()); assertNotNull(graph); @@ -306,7 +243,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { val rawGraph = GraphDef.parseFrom(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()); val nodeNames = rawGraph.getNodeList().stream().map(node -> node.getName()).collect(Collectors.toList()); System.out.println(nodeNames); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()); val convNode = tg.getVariable("conv2d/kernel"); @@ -322,14 +259,14 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testIntermediate2() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_lstm.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_lstm.pb").getInputStream()); } @Test public void testIntermediate1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()); assertTrue(tg.getVariable("input") != null); // assertTrue(tg.getVariableSpace().getVariable("input").isPlaceholder()); @@ -348,7 +285,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testIntermediateLoop1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/simple_while.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/simple_while.pb.txt").getInputStream()); assertNotNull(tg); @@ -363,7 +300,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void testWeirdConvImport() { - val tg = TFGraphMapper.getInstance().importGraph(new File("/home/agibsonccc/code/raver_tfimport_test1/profiling_conv.pb.txt")); + val tg = TFGraphMapper.importGraph(new File("/home/agibsonccc/code/raver_tfimport_test1/profiling_conv.pb.txt")); assertNotNull(tg); } @@ -371,7 +308,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testIntermediateLoop3() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/nested_while.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/nested_while.pb.txt").getInputStream()); assertNotNull(tg); @@ -397,7 +334,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Ignore public void testIntermediateStridedSlice1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_slice.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_slice.pb.txt").getInputStream()); assertNotNull(tg); @@ -473,7 +410,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Ignore public void testIntermediateTensorArraySimple1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); tg.setArrayForVariable("input_matrix",Nd4j.ones(3,2)); assertNotNull(tg); @@ -500,7 +437,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Ignore public void testIntermediateTensorArrayLoop1() throws Exception { val input = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); tg.setArrayForVariable("input_matrix",input); assertNotNull(tg); @@ -545,7 +482,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testIntermediateReduction() throws Exception { Nd4j.create(1); - SameDiff tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); + SameDiff tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); SDVariable sumResultVar = tg.getVariable("Sum"); /* val func = tg.getFunctionForVertexId(sumResultVar.getVertexId()); @@ -709,7 +646,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } */ - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/bias_add/frozen_model.pb").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/bias_add/frozen_model.pb").getInputStream()); assertNotNull(graph); INDArray input = Nd4j.linspace(1,40,40, DataType.FLOAT).reshape(10,4); @@ -724,7 +661,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testImportMapping1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()); val variables = new HashMap(); for (val var : tg.variables()) { @@ -744,7 +681,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testCondMapping1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0_1.fb")); @@ -759,7 +696,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testCondMapping2() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input = Nd4j.create(2, 2).assign(-1); @@ -776,7 +713,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileMapping1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input = Nd4j.create(2, 2).assign(1); tg.associateArrayWithVariable(input, tg.getVariable("input_0")); @@ -795,7 +732,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileMapping2() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input = Nd4j.scalar(4.0); tg.associateArrayWithVariable(input, tg.getVariable("input_1")); @@ -813,7 +750,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileMapping3() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input = Nd4j.scalar(9.0); tg.associateArrayWithVariable(input, tg.getVariable("input_1")); @@ -832,7 +769,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileDualMapping1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(2, 2).assign(-4.0); val input1 = Nd4j.scalar(1.0); @@ -852,7 +789,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileDualMapping2() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(2, 2).assign(-9.0); val input1 = Nd4j.scalar(1.0); @@ -873,7 +810,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testMixedWhileCond1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(2, 2).assign(1.0); val input1 = Nd4j.create(3, 3).assign(2.0); @@ -896,7 +833,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Ignore public void testProfConv() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new File("/home/raver119/develop/workspace/models/profiling_conv.pb.txt")); + val tg = TFGraphMapper.importGraph(new File("/home/raver119/develop/workspace/models/profiling_conv.pb.txt")); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/profiling_conv.fb")); @@ -907,7 +844,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_matrix_diag() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(2, 5, 4).assign(1.0); @@ -926,7 +863,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_tensor_dot_misc() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/tensor_dot_misc/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/tensor_dot_misc/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(36, 3, 4, 5).assign(1.0); @@ -943,7 +880,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_transpose() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/transpose/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/transpose/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(new double[]{0.98114507, 0.96400015, 0.58669623, 0.60073098, 0.75425418, 0.44258752, 0.76373084, 0.96593234, 0.34067846}, new int[] {3, 3}); @@ -960,7 +897,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_simpleif_0() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {2, 2}); @@ -977,7 +914,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_ae_00() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(new double[] {0.98174960, 0.44406342, 0.50100771, 1.00000000, -0.94038386, 0.46501783, -0.49040590, 0.98153842, -0.00198260, 0.49108310, -0.06085236, 0.93523693, -0.05857396, -0.46633510, -0.02806635, -0.96879626, -0.03938015, -0.51578135, -0.06333921, -1.00000000}, new int[] {5, 4}); @@ -992,7 +929,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_expand_dim() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(new double[] {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743}, new int[] {3, 4}); @@ -1007,7 +944,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_reduce_dim_false() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); assertNotNull(tg); @@ -1019,7 +956,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_reduce_dim_true() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim_true.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/reduce_dim_true.pb.txt").getInputStream()); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/reduce_dim_true.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build(), true); @@ -1027,7 +964,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testTensorArray_119_1() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); assertNotNull(tg); val input_matrix = Nd4j.ones(3, 2); @@ -1040,7 +977,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testTensorArray_119_2() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_read.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_read.pb.txt").getInputStream()); assertNotNull(tg); val input_matrix = Nd4j.ones(3, 2); @@ -1057,7 +994,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testTensorArray_119_3() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_unstack.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_unstack.pb.txt").getInputStream()); assertNotNull(tg); val array = tg.execSingle(Collections.emptyMap(), tg.outputs().get(0)); @@ -1069,7 +1006,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testTensorArray_119_4() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); assertNotNull(tg); val input_matrix = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2); @@ -1084,7 +1021,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testLossImport_1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/losses/log_loss_rank2_axis1_SUM_OVER_BATCH_SIZE/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/losses/log_loss_rank2_axis1_SUM_OVER_BATCH_SIZE/frozen_model.pb").getInputStream()); tg.execAndEndResult(); } @@ -1092,7 +1029,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testG_1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/g_08/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/g_08/frozen_model.pb").getInputStream()); val g = tg.asFlatBuffers(true); } @@ -1101,7 +1038,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testBoolImport_1() throws Exception { Nd4j.create(1); for (int e = 0; e < 1000; e++){ - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/reduce_any/rank0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/reduce_any/rank0/frozen_model.pb").getInputStream()); Map result = tg.exec(Collections.emptyMap(), tg.outputs()); @@ -1113,7 +1050,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testLogical_1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/transforms/logicalxor_3,4_3,4/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/transforms/logicalxor_3,4_3,4/frozen_model.pb").getInputStream()); tg.execAndEndResult(); } @@ -1123,18 +1060,18 @@ public class TensorFlowImportTest extends BaseNd4jTest { // tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb").getInputStream()); assertNotNull(tg); } @Test(expected = ND4JIllegalStateException.class) public void testNonFrozenGraph1() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/unfrozen_simple_ae.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/unfrozen_simple_ae.pb").getInputStream()); } @Test public void testRandomGraph() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/scalar_float32/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/scalar_float32/frozen_model.pb").getInputStream()); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/scalar_float32.fb")); @@ -1142,7 +1079,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testRandomGraph2() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new File("c:\\develop\\mobilenet_v2_1.0_224_frozen.pb")); + val tg = TFGraphMapper.importGraph(new File("c:\\develop\\mobilenet_v2_1.0_224_frozen.pb")); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/mobilenet_v2.fb")); @@ -1151,7 +1088,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void testRandomGraph3() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/3,4_3,4_float32/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/3,4_3,4_float32/frozen_model.pb").getInputStream()); assertNotNull(tg); log.info("{}", tg.asFlatPrint()); @@ -1161,7 +1098,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testControlDependencies1() throws Exception { - SameDiff sd = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/cond/cond_true/frozen_model.pb").getInputStream()); + SameDiff sd = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/cond/cond_true/frozen_model.pb").getInputStream()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java index 1035bda7a..5b4c84b4a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java @@ -103,7 +103,7 @@ public class ImportModelDebugger { File modelFile = new File("C:\\Temp\\TF_Graphs\\cifar10_gan_85\\tf_model.pb"); File rootDir = new File("C:\\Temp\\TF_Graphs\\cifar10_gan_85"); - SameDiff sd = TFGraphMapper.getInstance().importGraph(modelFile); + SameDiff sd = TFGraphMapper.importGraph(modelFile); ImportDebugListener l = ImportDebugListener.builder(rootDir) .checkShapesOnly(true) diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/collections/WeakIdentityHashMap.java b/nd4j/nd4j-common/src/main/java/org/nd4j/collections/WeakIdentityHashMap.java new file mode 100644 index 000000000..c336befd7 --- /dev/null +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/collections/WeakIdentityHashMap.java @@ -0,0 +1,161 @@ +package org.nd4j.collections; + +import lombok.*; + +import java.lang.ref.Reference; +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; +import java.util.*; + +/** + * A hash map implementation with weak identity keys. + * For details, see {@link WeakHashMap} and {@link IdentityHashMap} + * + * @param Key type + * @param Value type + * @author Alex Black + */ +public class WeakIdentityHashMap implements Map { + + protected final Map, V> map; + protected final ReferenceQueue refQueue; + + public WeakIdentityHashMap(){ + map = new HashMap<>(); + refQueue = new ReferenceQueue<>(); + } + + //Clear references to any map keys that have been GC'd + protected void clearReferences(){ + Reference r; + while((r = refQueue.poll()) != null){ + map.remove(r); + } + } + + @Override + public int size() { + clearReferences(); + return map.size(); + } + + @Override + public boolean isEmpty() { + clearReferences(); + return map.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + clearReferences(); + return map.containsKey(new KeyRef<>(key)); + } + + @Override + public boolean containsValue(Object value) { + clearReferences(); + return map.containsValue(value); + } + + @Override + public V get(Object key) { + clearReferences(); + return map.get(new KeyRef<>(key)); + } + + @Override + public V put(K key, V value) { + clearReferences(); + map.put(new KeyRef<>(key), value); + return value; + } + + @Override + public V remove(Object key) { + clearReferences(); + return map.remove(new KeyRef<>(key)); + } + + @Override + public void putAll(Map m) { + clearReferences(); + for(Map.Entry e : m.entrySet()){ + map.put(new KeyRef<>(e.getKey()), e.getValue()); + } + } + + @Override + public void clear() { + map.clear(); + clearReferences(); + } + + @Override + public Set keySet() { + clearReferences(); + Set ret = new HashSet<>(); + for(KeyRef k : map.keySet() ){ + K key = k.get(); + if(key != null) + ret.add(key); + } + return ret; + } + + @Override + public Collection values() { + clearReferences(); + return map.values(); + } + + @Override + public Set> entrySet() { + clearReferences(); + Set> ret = new HashSet<>(); + for(Map.Entry, V> e : map.entrySet()){ + K k = e.getKey().get(); + if(k != null){ + ret.add(new Entry(k, e.getValue())); + } + } + return ret; + } + + + protected static class KeyRef extends WeakReference { + private final int hash; + public KeyRef(@NonNull K referent) { + super(referent); + this.hash = System.identityHashCode(referent); + } + + @Override + public int hashCode(){ + return hash; + } + + @Override + public boolean equals(Object o){ + if(this == o){ + return true; + } + if(o instanceof WeakReference){ + return this.get() == ((WeakReference) o).get(); + } + return false; + } + } + + @Data + @AllArgsConstructor + protected static class Entry implements Map.Entry { + protected K key; + protected V value; + + @Override + public V setValue(V value){ + this.value = value; + return value; + } + } +} diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java index e063d16ae..b1ea7e76d 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java +++ b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java @@ -147,7 +147,7 @@ public class GraphInferenceGrpcClient { val arrOff = array.toFlatArray(builder); byte variableType = 0; //TODO is this OK here? - val varOff = FlatVariable.createFlatVariable(builder, idPair, nameOff, FlatBuffersMapper.getDataTypeAsByte(array.dataType()),0, arrOff, -1, variableType); + val varOff = FlatVariable.createFlatVariable(builder, idPair, nameOff, FlatBuffersMapper.getDataTypeAsByte(array.dataType()),0, arrOff, -1, variableType, 0, 0, 0); ins[cnt++] = varOff; } diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java index d8dd3f4df..1c84da010 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java +++ b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java @@ -43,7 +43,7 @@ public class GraphInferenceGrpcClientTest { val graphId = RandomUtils.nextLong(0, Long.MAX_VALUE); // preparing and registering graph (it's optional, and graph might be embedded into Docker image - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); assertNotNull(tg); client.registerGraph(graphId, tg, ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build()); @@ -66,7 +66,7 @@ public class GraphInferenceGrpcClientTest { val graphId = RandomUtils.nextLong(0, Long.MAX_VALUE); // preparing and registering graph (it's optional, and graph might be embedded into Docker image - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); assertNotNull(tg); client.registerGraph(graphId, tg, ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build()); diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/ProtoBufToFlatBufConversion.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/ProtoBufToFlatBufConversion.java index 73bd00036..ee53c37e2 100644 --- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/ProtoBufToFlatBufConversion.java +++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/ProtoBufToFlatBufConversion.java @@ -56,7 +56,7 @@ public class ProtoBufToFlatBufConversion { */ public static void convert(String inFile, String outFile) throws IOException, org.nd4j.linalg.exception.ND4JIllegalStateException { - SameDiff tg = TFGraphMapper.getInstance().importGraph(new File(inFile)); + SameDiff tg = TFGraphMapper.importGraph(new File(inFile)); tg.asFlatFile(new File(outFile)); } @@ -90,7 +90,7 @@ public class ProtoBufToFlatBufConversion { }; - SameDiff sd = TFGraphMapper.getInstance().importGraph(new File(inFile), m, filter); + SameDiff sd = TFGraphMapper.importGraph(new File(inFile), m, filter); SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/mul")) // .../dropout/mul