From d333d29099247876d86a3473ab7e0be42eebbb6b Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 26 Oct 2019 12:38:08 +1100 Subject: [PATCH] SameDiff cleanup and fixes (#12) * #8160 Remove resolvePrepertiesFromSameDiffBeforeExecution Signed-off-by: AlexDBlack * SameDiff API cleanup Signed-off-by: AlexDBlack * More SameDiff cleanup Signed-off-by: AlexDBlack * Small fixes Signed-off-by: AlexDBlack * #8248 Switch SameDiff variable init from lazy to creation time for more predictable behaviour Signed-off-by: AlexDBlack * #8252 TanhDerivative javadoc Signed-off-by: AlexDBlack * #8225 Deconvolution2D input validation Signed-off-by: AlexDBlack * #8265 Switch SameDiff.outputs() to user settable, instead of unreliable 'best guess' Signed-off-by: AlexDBlack * #8224 SameDiff.zero and .one create constants, not variables Signed-off-by: AlexDBlack * More cleanup and fixes Signed-off-by: AlexDBlack * Small test fix Signed-off-by: AlexDBlack * Small fix Signed-off-by: AlexDBlack * DL4J SameDiff fixes Signed-off-by: AlexDBlack * Re-add hack for Deconvolution2DLayer until #8315 is resolved Signed-off-by: AlexDBlack * #8270 Move CUDA device/version logging to Java; can be disabled via existing org.nd4j.log.initialization system property Signed-off-by: AlexDBlack * All ND4J init logging checks system property Signed-off-by: AlexDBlack * Small tweak Signed-off-by: AlexDBlack * Remove redundant device logging Signed-off-by: AlexDBlack * One more fix Signed-off-by: AlexDBlack * UX improvements Signed-off-by: AlexDBlack * Deconv fix Signed-off-by: AlexDBlack * Add deconv tests Signed-off-by: AlexDBlack * Cleanup Signed-off-by: AlexDBlack * Remove debug code Signed-off-by: AlexDBlack --- .../convolution/ConvolutionLayerTest.java | 19 + .../CompareTrainingImplementations.java | 20 +- .../layers/samediff/SameDiffLambdaVertex.java | 2 +- .../convolution/Deconvolution2DLayer.java | 4 +- .../layers/samediff/SameDiffGraphVertex.java | 4 +- .../nn/layers/samediff/SameDiffLayer.java | 2 +- .../layers/samediff/SameDiffOutputLayer.java | 24 +- libnd4j/blas/Environment.cpp | 27 +- libnd4j/blas/Environment.h | 4 + .../declarable/generic/nn/convo/deconv2d.cpp | 6 +- .../functions/DifferentialFunction.java | 79 +- .../DifferentialFunctionFactory.java | 13 +- .../listeners/ListenerEvaluations.java | 4 +- .../autodiff/listeners/ListenerVariables.java | 2 +- .../autodiff/listeners/impl/UIListener.java | 2 +- .../listeners/records/EvaluationRecord.java | 12 +- .../autodiff/listeners/records/History.java | 16 +- .../autodiff/listeners/records/LossCurve.java | 8 +- .../nd4j/autodiff/samediff/SDVariable.java | 100 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 1090 +++-------------- .../autodiff/samediff/TrainingConfig.java | 4 +- .../samediff/config/BatchOutputConfig.java | 24 +- .../samediff/config/EvaluationConfig.java | 6 +- .../samediff/config/OutputConfig.java | 2 +- .../samediff/internal/AbstractSession.java | 10 +- .../samediff/internal/InferenceSession.java | 74 +- .../samediff/internal/TrainingSession.java | 4 +- .../nd4j/autodiff/samediff/ops/SDBaseOps.java | 42 +- .../autodiff/samediff/ops/SDValidation.java | 26 +- .../samediff/serde/FlatBuffersMapper.java | 2 +- .../transform/GraphTransformUtil.java | 8 +- .../autodiff/samediff/transform/SubGraph.java | 2 +- .../samediff/transform/SubGraphPredicate.java | 4 +- .../autodiff/validation/GradCheckUtil.java | 54 +- .../autodiff/validation/OpValidation.java | 6 +- .../nd4j/autodiff/validation/TestCase.java | 4 +- .../java/org/nd4j/graph/ui/LogFileWriter.java | 12 +- .../converters/ImportClassMapping.java | 2 +- .../imports/graphmapper/tf/TFGraphMapper.java | 2 - .../linalg/api/ops/BaseBroadcastBoolOp.java | 8 - .../nd4j/linalg/api/ops/BaseBroadcastOp.java | 9 - .../linalg/api/ops/BaseIndexAccumulation.java | 17 +- .../java/org/nd4j/linalg/api/ops/BaseOp.java | 4 +- .../org/nd4j/linalg/api/ops/BaseReduceOp.java | 6 +- .../nd4j/linalg/api/ops/BaseScalarBoolOp.java | 5 +- .../org/nd4j/linalg/api/ops/BaseScalarOp.java | 5 +- .../nd4j/linalg/api/ops/BaseTransformOp.java | 32 +- .../nd4j/linalg/api/ops/DynamicCustomOp.java | 2 +- .../impl/layers/ExternalErrorsFunction.java | 8 +- .../ops/impl/layers/convolution/DeConv2D.java | 4 +- .../ops/impl/layers/convolution/DeConv3D.java | 4 +- .../nd4j/linalg/api/ops/impl/reduce/Mmul.java | 6 - .../linalg/api/ops/impl/shape/Concat.java | 5 +- .../linalg/api/ops/impl/shape/Transpose.java | 12 +- .../shape/tensorops/TensorArrayConcat.java | 2 +- .../shape/tensorops/TensorArrayGather.java | 2 +- .../impl/shape/tensorops/TensorArrayRead.java | 2 +- .../ops/impl/transforms/CheckNumerics.java | 4 +- .../api/ops/impl/transforms/Constant.java | 91 -- .../api/ops/impl/transforms/MaxOut.java | 2 +- .../transforms/gradient/TanhDerivative.java | 12 +- .../linalg/api/ops/random/BaseRandomOp.java | 11 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 31 +- .../org/nd4j/nativeblas/NativeOpsHolder.java | 8 +- .../jita/allocator/impl/MemoryTracker.java | 2 +- .../nd4j/linalg/jcublas/JCublasBackend.java | 40 + .../ops/executioner/CudaExecutioner.java | 8 - .../nd4j/linalg/cpu/nativecpu/CpuBackend.java | 5 + .../java/org/nd4j/autodiff/TestSessions.java | 6 +- .../execution/GraphExecutionerTest.java | 2 +- .../opvalidation/LayerOpValidation.java | 20 +- .../opvalidation/MiscOpValidation.java | 49 +- .../opvalidation/ReductionOpValidation.java | 14 +- .../opvalidation/RnnOpValidation.java | 12 +- .../opvalidation/ShapeOpValidation.java | 77 +- .../opvalidation/TransformOpValidation.java | 217 ++-- .../samediff/FailingSameDiffTests.java | 13 +- .../samediff/FlatBufferSerdeTest.java | 14 +- .../samediff/GraphTransformUtilTests.java | 24 +- .../autodiff/samediff/NameScopeTests.java | 28 +- .../SameDiffSpecifiedLossVarsTests.java | 12 +- .../nd4j/autodiff/samediff/SameDiffTests.java | 417 +++---- .../nd4j/autodiff/ui/FileReadWriteTests.java | 2 + .../org/nd4j/autodiff/ui/UIListenerTest.java | 2 +- .../java/org/nd4j/imports/ExecutionTests.java | 2 +- .../nd4j/imports/TFGraphs/BERTGraphTest.java | 12 +- .../TFGraphs/ValidateZooModelPredictions.java | 8 +- .../nd4j/imports/TensorFlowImportTest.java | 44 +- .../listeners/ImportModelDebugger.java | 2 +- .../nd4j/linalg/convolution/DeconvTests.java | 93 ++ .../org/nd4j/linalg/factory/Nd4jBackend.java | 15 +- .../nd4j/remote/serving/SameDiffServlet.java | 2 +- 92 files changed, 1204 insertions(+), 1956 deletions(-) delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Constant.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java index 86278a793..1c4b764bd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java @@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.exception.DL4JException; +import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -693,4 +694,22 @@ public class ConvolutionLayerTest extends BaseDL4JTest { INDArray out = net.output(in); assertArrayEquals(new long[]{2,7,6}, out.shape()); } + + @Test + public void testDeconvBadInput(){ + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .list() + .layer(new Deconvolution2D.Builder().nIn(5).nOut(3).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + INDArray badInput = Nd4j.create(DataType.FLOAT, 1, 10, 5, 5); + try { + net.output(badInput); + } catch (DL4JInvalidInputException e){ + String msg = e.getMessage(); + assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels")); + } + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index 9570166ed..12564f01a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -86,10 +86,10 @@ public class CompareTrainingImplementations extends BaseDL4JTest { SDVariable label = sd.placeHolder("label", DataType.DOUBLE, -1, 3); SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.DOUBLE, 4, 10); - SDVariable b0 = sd.zero("b0", 1, 10); + SDVariable b0 = sd.var("b0", Nd4j.create(DataType.DOUBLE, 1, 10)); SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.DOUBLE, 10, 3); - SDVariable b1 = sd.zero("b1", 1, 3); + SDVariable b1 = sd.var("b1", Nd4j.create(DataType.DOUBLE, 1, 3)); SDVariable z0 = in.mmul(w0).add(b0); SDVariable a0 = sd.nn().tanh(z0); @@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest { Map placeholders = new HashMap<>(); placeholders.put("input", f); placeholders.put("label", l); - Map map = sd.output(placeholders, lossMse.getVarName(), a1.getVarName()); - INDArray outSd = map.get(a1.getVarName()); + Map map = sd.output(placeholders, lossMse.name(), a1.name()); + INDArray outSd = map.get(a1.name()); INDArray outDl4j = net.output(f); assertEquals(testName, outDl4j, outSd); @@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Check score double scoreDl4j = net.score(); - double scoreSd = map.get(lossMse.getVarName()).getDouble(0) + sd.calcRegularizationScore(); + double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore(); assertEquals(testName, scoreDl4j, scoreSd, 1e-6); double lossRegScoreSD = sd.calcRegularizationScore(); @@ -197,15 +197,15 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Check gradients (before updater applied) Map grads = net.gradient().gradientForVariable(); - sd.execBackwards(placeholders); + Map gm = sd.calculateGradients(placeholders, b1.name(), w1.name(), b0.name(), w0.name()); //Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only //We can check correctness though with training param checks later if(l1Val == 0 && l2Val == 0 && wdVal == 0) { - assertEquals(testName, grads.get("1_b"), b1.getGradient().getArr()); - assertEquals(testName, grads.get("1_W"), w1.getGradient().getArr()); - assertEquals(testName, grads.get("0_b"), b0.getGradient().getArr()); - assertEquals(testName, grads.get("0_W"), w0.getGradient().getArr()); + assertEquals(testName, grads.get("1_b"), gm.get(b1.name())); + assertEquals(testName, grads.get("1_W"), gm.get(w1.name())); + assertEquals(testName, grads.get("0_b"), gm.get(b0.name())); + assertEquals(testName, grads.get("0_W"), gm.get(w0.name())); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java index f10ae5bad..6890f83e8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java @@ -65,7 +65,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex { defineVertex(temp, tempInputs); List list = new ArrayList<>(); for (Integer i : tempInputs.map.keySet()) { - list.add(tempInputs.map.get(i).getVarName()); + list.add(tempInputs.map.get(i).name()); } params.defineInputs(list.toArray(new String[list.size()])); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java index 8a6c187c2..9601589bb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java @@ -176,8 +176,10 @@ public class Deconvolution2DLayer extends ConvolutionLayer { int outDepth = (int) weights.size(1); if (input.size(1) != inDepth && input.size(3) == inDepth) { + //TODO AB 2019/10/25 this is an ugly "pseudo-NHWC support" hack that needs to be removed ASAD + //https://github.com/eclipse/deeplearning4j/issues/8315 input = input.permute(0, 3, 1, 2); - } else if (input.size(1) != inDepth && input.size(3) != inDepth) { + } else if (input.size(1) != inDepth ) { String layerName = conf.getLayer().getLayerName(); if (layerName == null) layerName = "(not named)"; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java index 34799d6ad..3e1d1b831 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java @@ -192,7 +192,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex { String name = inputs.get(j); dLdIns[j] = sameDiff.grad(name).getArr(); - String gradName = sameDiff.grad(inputNames.get(j)).getVarName(); + String gradName = sameDiff.grad(inputNames.get(j)).name(); if(dLdIns[j] == null && fnName.equals(gradName)){ //Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here @@ -271,7 +271,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex { fn = sameDiff.f().externalErrors(layerOutput); fn.outputVariable(); - this.outputKey = outputVar.getVarName(); + this.outputKey = outputVar.name(); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java index 434a35adc..5723861c1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java @@ -302,7 +302,7 @@ public class SameDiffLayer extends AbstractLayer { fn = sameDiff.f().externalErrors(layerOutput); fn.outputVariable(); - this.outputKey = outputVar.getVarName(); + this.outputKey = outputVar.name(); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java index e5ca125cd..e6d9c2a7e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java @@ -112,7 +112,7 @@ public class SameDiffOutputLayer extends AbstractLayer gradVarNames = new ArrayList<>(); - for(String s : paramTable.keySet()){ - gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName()); - } - gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName()); + gradVarNames.addAll(paramTable.keySet()); + gradVarNames.add(INPUT_KEY); Map phMap = new HashMap<>(); phMap.put(INPUT_KEY, input); phMap.put(LABELS_KEY, labels); - sameDiff.execBackwards(phMap, gradVarNames); + Map grads = sameDiff.calculateGradients(phMap, gradVarNames); for(String s : paramTable.keySet() ){ - INDArray sdGrad = sameDiff.grad(s).getArr(); + INDArray sdGrad = grads.get(s); INDArray dl4jGrad = gradTable.get(s); dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS g.gradientForVariable().put(s, dl4jGrad); + if(sdGrad.closeable()){ + sdGrad.close(); + } } - dLdIn = sameDiff.grad(INPUT_KEY).getArr(); + dLdIn = grads.get(INPUT_KEY); } //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere sameDiff.clearPlaceholders(true); sameDiff.clearOpInputs(); - return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS + Pair p = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS + if(dLdIn.closeable()) + dLdIn.close(); + return p; } /**Returns the parameters of the neural network as a flattened row vector @@ -297,7 +301,7 @@ public class SameDiffOutputLayer extends AbstractLayer& capabilities(); }; } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp index f69b6c0f9..caff807f8 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -66,8 +66,10 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { if(!isNCHW) output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + if(isSameMode){ // SAME + //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); + } NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 5ce25628e..32df3e69d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -442,22 +442,12 @@ public abstract class DifferentialFunction { setInstanceId(); if(sameDiff != null) { sameDiff.addArgsFor(args, this); - for (int i = 0; i < args.length; i++) { - if (args[i].isPlaceHolder()) { - sameDiff.addPropertyToResolve(this, args[i].getVarName()); - } - } } } public void replaceArg(int i, SDVariable newArg){ if(sameDiff != null){ sameDiff.replaceArgFor(i, newArg, this); - if(args()[i].isPlaceHolder() && !newArg.isPlaceHolder()){ - sameDiff.removePropertyToResolve(this, args()[i].getVarName()); - } else if(!args()[i].isPlaceHolder() && newArg.isPlaceHolder()){ - sameDiff.addPropertyToResolve(this, newArg.getVarName()); - } } } @@ -483,7 +473,7 @@ public abstract class DifferentialFunction { SDVariable[] outputVars = outputVariables(); String[] out = new String[outputVars.length]; for( int i=0; i 0 ? shape : null)).outputVariable(); - } - public SDVariable linspace(SDVariable lower, SDVariable upper, SDVariable count, DataType dt) { return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sameDiff(), lower, upper, count, dt).outputVariable(); } @@ -1056,7 +1045,7 @@ public class DifferentialFunctionFactory { public SDVariable gradientBackwardsMarker(SDVariable iX) { - return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.getVarName() + "-pairgrad", 1.0)).outputVariable(); + return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.name() + "-pairgrad", 1.0)).outputVariable(); } public SDVariable abs(SDVariable iX) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java index 08722b06e..9bdc54dd7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java @@ -178,7 +178,7 @@ public class ListenerEvaluations { * @param evaluations The evaluations to run */ public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) { - return trainEvaluation(variable.getVarName(), labelIndex, evaluations); + return trainEvaluation(variable.name(), labelIndex, evaluations); } /** @@ -202,7 +202,7 @@ public class ListenerEvaluations { * @param evaluations The evaluations to run */ public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) { - return validationEvaluation(variable.getVarName(), labelIndex, evaluations); + return validationEvaluation(variable.name(), labelIndex, evaluations); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java index 34b305001..33baf7099 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java @@ -167,7 +167,7 @@ public class ListenerVariables { String[] names = new String[variables.length]; for (int i = 0; i < variables.length; i++) - names[i] = variables[i].getVarName(); + names[i] = variables[i].name(); return requireVariables(op, names); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java index 452636d57..6c38c6c9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java @@ -226,7 +226,7 @@ public class UIListener extends BaseListener { List sdVars = sd.variables(); List varNames = new ArrayList<>(sdVars.size()); for(SDVariable v : sdVars){ - varNames.add(v.getVarName()); + varNames.add(v.name()); } if(varNames.size() != vars.size() || !varNames.containsAll(vars)){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java index b063e18a2..149ea5f2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java @@ -91,7 +91,7 @@ public class EvaluationRecord { * @param param The target param/variable */ public List evaluations(SDVariable param) { - return evaluations(param.getVarName()); + return evaluations(param.name()); } /** @@ -105,7 +105,7 @@ public class EvaluationRecord { * Get the evaluation for param at the specified index */ public IEvaluation evaluation(SDVariable param, int index) { - return evaluation(param.getVarName(), index); + return evaluation(param.name(), index); } /** @@ -132,7 +132,7 @@ public class EvaluationRecord { * @param param The target param/variable */ public T evaluation(SDVariable param) { - return evaluation(param.getVarName()); + return evaluation(param.name()); } /** @@ -174,7 +174,7 @@ public class EvaluationRecord { * @param evalClass The type of evaluation to look for */ public > T evaluation(SDVariable param, Class evalClass) { - return evaluation(param.getVarName(), evalClass); + return evaluation(param.name(), evalClass); } /** @@ -209,7 +209,7 @@ public class EvaluationRecord { * @param metric The metric to calculate */ public double getValue(SDVariable param, IMetric metric) { - return getValue(param.getVarName(), metric); + return getValue(param.name(), metric); } /** @@ -235,7 +235,7 @@ public class EvaluationRecord { * @param metric The metric to calculate */ public double getValue(SDVariable param, int index, IMetric metric) { - return getValue(param.getVarName(), index, metric); + return getValue(param.name(), index, metric); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java index 92d1ce120..809400c0b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java @@ -125,7 +125,7 @@ public class History { * Only works if there is only one evaluation with the given metric for param */ public List trainingEval(SDVariable param, IMetric metric){ - return trainingEval(param.getVarName(), metric); + return trainingEval(param.name(), metric); } /** @@ -149,7 +149,7 @@ public class History { * Index determines the evaluation used not the epoch's results to return. */ public List trainingEval(SDVariable param, int index, IMetric metric){ - return trainingEval(param.getVarName(), index, metric); + return trainingEval(param.name(), index, metric); } /** @@ -184,7 +184,7 @@ public class History { * Only works if there is only one evaluation for param. */ public List trainingEval(SDVariable param){ - return trainingEval(param.getVarName()); + return trainingEval(param.name()); } /** @@ -208,7 +208,7 @@ public class History { * Index determines the evaluation used not the epoch's results to return. */ public List trainingEval(SDVariable param, int index){ - return trainingEval(param.getVarName(), index); + return trainingEval(param.name(), index); } /** @@ -230,7 +230,7 @@ public class History { * Only works if there is only one evaluation with the given metric for param */ public List validationEval(SDVariable param, IMetric metric){ - return validationEval(param.getVarName(), metric); + return validationEval(param.name(), metric); } /** @@ -254,7 +254,7 @@ public class History { * Index determines the evaluation used not the epoch's results to return. */ public List validationEval(SDVariable param, int index, IMetric metric){ - return validationEval(param.getVarName(), index, metric); + return validationEval(param.name(), index, metric); } /** @@ -289,7 +289,7 @@ public class History { * Only works if there is only one evaluation for param. */ public List validationEval(SDVariable param){ - return validationEval(param.getVarName()); + return validationEval(param.name()); } /** @@ -313,7 +313,7 @@ public class History { * Index determines the evaluation used not the epoch's results to return. */ public List validationEval(SDVariable param, int index){ - return validationEval(param.getVarName(), index); + return validationEval(param.name(), index); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java index 493950bbf..f3f24b98a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java @@ -116,7 +116,7 @@ public class LossCurve { * Return all mean loss values for a given variable */ public float[] meanLoss(@NonNull SDVariable loss){ - return meanLoss(loss.getVarName()); + return meanLoss(loss.name()); } /** @@ -143,7 +143,7 @@ public class LossCurve { * See {@link #meanLoss(int)} */ public float meanLoss(@NonNull SDVariable loss, int epoch){ - return meanLoss(loss.getVarName(), epoch); + return meanLoss(loss.name(), epoch); } /** @@ -162,7 +162,7 @@ public class LossCurve { * Return the mean loss value for a given variable on the last epoch. */ public float lastMeanLoss(@NonNull SDVariable loss){ - return lastMeanLoss(loss.getVarName()); + return lastMeanLoss(loss.name()); } /** @@ -189,7 +189,7 @@ public class LossCurve { * A positive delta means the loss is increasing, and a negative delta means it is decreasing. */ public double lastMeanDelta(SDVariable loss){ - return lastMeanDelta(loss.getVarName()); + return lastMeanDelta(loss.name()); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index f2818e7e4..6d9e34ed0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -59,10 +59,6 @@ public class SDVariable implements Serializable { @Setter protected VariableType variableType; - @Getter - @Setter - protected WeightInitScheme weightInitScheme; - @Setter(AccessLevel.NONE) protected long[] shape; @@ -75,9 +71,7 @@ public class SDVariable implements Serializable { // autogen_tag::sdvars::start - public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType, WeightInitScheme weightInitScheme){ - Preconditions.checkState(weightInitScheme == null || varType == VariableType.VARIABLE, "Weight initalization schemes can only be applied to VARIABLE type" + - " SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme); + public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType){ Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName); varName = sameDiff.generateNewVarName(varName, 0, true); @@ -86,10 +80,25 @@ public class SDVariable implements Serializable { this.varName = varName; this.variableType = varType; this.dataType = dataType; - this.weightInitScheme = weightInitScheme; this.shape = shape; } + /** + * Get the name of the SDVariable + * @return Name of the variable + */ + public String name(){ + return varName; + } + + /** + * @deprecated Use {@link #name()} + */ + @Deprecated + public String getVarName(){ + return name(); + } + /** * Returns true if this variable is a place holder * @return @@ -102,30 +111,6 @@ public class SDVariable implements Serializable { return variableType == VariableType.CONSTANT; } - /** - * Allocate and return a new array - * based on the vertex id and weight initialization. - * @return the allocated array - */ - public INDArray storeAndAllocateNewArray() { - Preconditions.checkState(variableType == VariableType.VARIABLE, "Unable to allocate and store array for variable of type %s: only" + - " VARIABLE type variables can be initialized using this method", variableType); - - if(!sameDiff.arrayAlreadyExistsForVarName(varName)){ - long[] shape = getShape(); - INDArray arr = getWeightInitScheme().create(dataType(), shape); - sameDiff.associateArrayWithVariable(arr, this); - if(log.isTraceEnabled()){ - log.trace("Generated and stored new array for variable \"{}\": shape {}", getVarName(), Arrays.toString(arr.shape())); - } - return arr; - } - - //Variable type SDVariables: shape should never change (i.e., these are params in the net!) - INDArray ret = getArr(); - return ret; - } - /** * A getter for the allocated ndarray with this {@link SDVariable}. * @@ -155,30 +140,14 @@ public class SDVariable implements Serializable { public INDArray getArr(boolean enforceExistence){ if(sameDiff.arrayAlreadyExistsForVarName(getVarName())) return sameDiff.getArrForVarName(getVarName()); - if(variableType == VariableType.ARRAY){ throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead"); } - - //initialize value if it's actually a scalar constant (zero or 1 typically...) - if(variableType == VariableType.VARIABLE && weightInitScheme != null && shape != null){ - INDArray arr = weightInitScheme.create(dataType, shape); - sameDiff.associateArrayWithVariable(arr, this); - if(log.isTraceEnabled()){ - log.trace("getArr() for variable \"{}\" allocated new array: shape {}", getVarName(), Arrays.toString(getShape())); - } - return arr; - } else if(sameDiff.getShapeForVarName(getVarName()) == null) { - if (enforceExistence) { - throw new IllegalStateException("Cannot get array for SDVariable \"" + getVarName() + "\": no array has" + - " been defined, and array shape cannot be calculated"); - } - if(log.isTraceEnabled()){ - log.trace("SDVariable.getArr(): could not get array for variable {}: shape is null", getVarName()); - } - return null; + INDArray ret = sameDiff.getArrForVarName(getVarName()); + if(enforceExistence && ret == null){ + throw new IllegalStateException("No array exists for variable \"" + name() + "\""); } - return sameDiff.getArrForVarName(getVarName()); + return ret; } @@ -215,21 +184,13 @@ public class SDVariable implements Serializable { * @return Shape of the variable */ public long[] getShape() { - if (variableType == VariableType.PLACEHOLDER && getArr() == null) { - if (shape != null) + if (variableType == VariableType.PLACEHOLDER ) { return shape; - else - return new long[0]; + } else if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT){ + return getArr().shape(); } - long[] initialShape = sameDiff.getShapeForVarName(getVarName()); - if(initialShape == null && variableType != VariableType.ARRAY) { - val arr = getArr(); - if(arr != null) - return arr.shape(); - } - - return initialShape; + return null; } public void setShape(long... shape){ @@ -1488,8 +1449,8 @@ public class SDVariable implements Serializable { * @return */ public INDArray eval() { - sameDiff.exec(null, getVarName()); - return getArr(); + Map m = sameDiff.output((Map)null, name()); + return m.get(name()); } @@ -1498,8 +1459,8 @@ public class SDVariable implements Serializable { * @return */ public INDArray eval(Map placeholders) { - sameDiff.exec(placeholders, getVarName()); - return getArr(); + Map m = sameDiff.output(placeholders, name()); + return m.get(name()); } @@ -1519,7 +1480,7 @@ public class SDVariable implements Serializable { */ public void addControlDependency(SDVariable controlDependency){ Variable vThis = sameDiff.getVariables().get(getVarName()); - Variable vCD = sameDiff.getVariables().get(controlDependency.getVarName()); + Variable vCD = sameDiff.getVariables().get(controlDependency.name()); //If possible: add control dependency on ops if(vThis.getOutputOfOp() != null && vCD.getOutputOfOp() != null ){ @@ -1729,7 +1690,6 @@ public class SDVariable implements Serializable { SDVariable v = new SDVariable(); v.varName = varName; v.variableType = variableType; - v.weightInitScheme = weightInitScheme; v.shape = shape == null ? null : shape.clone(); v.dataType = dataType; v.sameDiff = sd; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 1bcb3aedb..284c3e6ac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -114,7 +114,7 @@ public class SameDiff extends SDBaseOps { protected static final String GRAD_FN_KEY = "grad"; //Fields for graph structure and execution - @Getter //TODO use package private instead of public getters? + @Getter private final Map variables = new LinkedHashMap<>(); //Use linked hash map to guarantee iteration order based on order they were added. Used in inputs() and flatbuffers serde @Getter private final Map ops = new LinkedHashMap<>(); @@ -131,6 +131,8 @@ public class SameDiff extends SDBaseOps { private final List nameScopes = new ArrayList<>(); //Used as a stack + private List outputs; //Names of the output variables, set by the user. + /////////////////////////////////////// //Fields related to training @Getter @@ -141,15 +143,8 @@ public class SameDiff extends SDBaseOps { private Map updaterMap; //GradientUpdater instance for each trainable parameter //////////////////////////////////////// - //map a function's instance id to a base name, used for propagating variable names - //for output during import - private Map baseNameForFunctionInstanceId; private DifferentialFunctionFactory functionFactory; - @Deprecated //TO BE REMOVED - to ShapeSession - private Map variableNameToShape; //Key: SDVariable name. Value: shape for that variable - @Deprecated //TO BE REMOVED - to Variable - private Map forwardVarForGrad; // counter for auto-naming variables private int variableId = 0; @@ -246,38 +241,7 @@ public class SameDiff extends SDBaseOps { return bitwise; } - - /** - * For import, many times we have variables - * that map to properties. Most common - * we will have an input to a function that is mapped to an ndarray. - * That ndarray is usually a scalar shape. - *

- * That array with a scalar shape can be something like an axis. - *

- * We often don't know that array's value till run time. - * This map stores variable names that we should resolve - * from samediff. We use the value of that array - * to update the properties. - */ - private Map> propertiesToResolve; - - /** - * A map of own name to - * the properties of the function (things like execution axes etc) - * The valid values can be: - * int - * long - * INDArray - */ - private Map> propertiesForFunction; - - @Deprecated //TO BE REMOVED - to Variable - private Map placeHolderOriginalShapes; - private Map sameDiffFunctionDefinitionMap; private Map sameDiffFunctionInstances; - private Set placeHolderFunctions; - private static Map opMethods; private Table fieldVariableResolutionMapping; @@ -288,9 +252,6 @@ public class SameDiff extends SDBaseOps { //debug mode variables @Getter private boolean debugMode; - private Map opsForResult; - private boolean resolvedVariables = false; - @Getter private Stack argumentInterceptors = new Stack<>(); @@ -309,110 +270,6 @@ public class SameDiff extends SDBaseOps { @Getter private SameDiff child; - static { - opMethods = new HashMap<>(); - Method[] methods = SameDiff.class.getDeclaredMethods(); - for (Method method : methods) { - if (method.getReturnType().equals(SDVariable.class)) { - opMethods.put(method.getName(), method); - } - } - } - - - /** - * Update the opName for the variable with the given vertex id - * - * @param varName the vertex id to update - * @param withName thew new opName - */ - public void updateVariableName(String varName, String withName) { - SDVariable oldVarNameRef = getVariable(varName); - Variable v = variables.remove(varName); - String oldVarName = varName; - oldVarNameRef.setVarName(withName); - v.setName(withName); - variables.put(withName, v); - - for (SameDiffOp op : ops.values()) { - List outputsOfOp = op.getOutputsOfOp(); - if (outputsOfOp != null && !outputsOfOp.isEmpty()) { - for (int i = 0; i < outputsOfOp.size(); i++) { - if (outputsOfOp.get(i).equals(oldVarName)) { - outputsOfOp.set(i, withName); - } - } - } - - List inputsToOp = op.getInputsToOp(); - if (inputsToOp != null && !inputsToOp.isEmpty()) { - for (int i = 0; i < inputsToOp.size(); i++) { - if (inputsToOp.get(i).equals(oldVarName)) { - inputsToOp.set(i, withName); - } - } - } - } - -// if (variableNameToArr.containsKey(oldVarName)) { -// val arr = variableNameToArr.remove(oldVarName); -// variableNameToArr.put(withName, arr); -// } - - - if (variableNameToShape.containsKey(oldVarName)) { - val shape = variableNameToShape.remove(oldVarName); - variableNameToShape.put(withName, shape); - } - - if (forwardVarForGrad.containsKey(oldVarName)) { - val forwardGrad = forwardVarForGrad.remove(oldVarName); - forwardVarForGrad.put(withName, forwardGrad); - } - - - if (v.getInputsForOp() != null) { - List funcNames = v.getInputsForOp(); - for (String s : funcNames) { - DifferentialFunction func = ops.get(s).getOp(); - if (func instanceof BaseOp) { - BaseOp baseOp = (BaseOp) func; - if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldVarName)) { - baseOp.setXVertexId(withName); - } - - if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldVarName)) { - baseOp.setYVertexId(withName); - } - - if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldVarName)) { - baseOp.setZVertexId(withName); - } - - } - } - } - - - if (v.getOutputOfOp() != null) { - DifferentialFunction func = ops.get(v.getOutputOfOp()).getOp(); - if (func instanceof BaseOp) { - BaseOp baseOp = (BaseOp) func; - if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldVarName)) { - baseOp.setXVertexId(withName); - } - - if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldVarName)) { - baseOp.setYVertexId(withName); - } - - if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldVarName)) { - baseOp.setZVertexId(withName); - } - } - } - } - /** * Clears debugging state and disables debug mode. @@ -550,9 +407,9 @@ public class SameDiff extends SDBaseOps { * } * SDVariable z = sd.var("z", DataType.FLOAT, 5); * - * String xName = x.getVarName(); //RESULT: "x" - * String yName = y.getVarName(); //RESULT: "myScope/y" - * String zName = z.getVarName(); //RESULT: "z" + * String xName = x.name(); //RESULT: "x" + * String yName = y.name(); //RESULT: "myScope/y" + * String zName = z.name(); //RESULT: "z" * } * *

@@ -566,7 +423,7 @@ public class SameDiff extends SDBaseOps { * x = sd.var("x", DataType.FLOAT, 5); * } * } - * String xName = x.getVarName(); //RESULT: "first/second/x" + * String xName = x.name(); //RESULT: "first/second/x" * } * * @@ -605,7 +462,7 @@ public class SameDiff extends SDBaseOps { public List getVariablesInScope(NameScope scope) { ArrayList vars = new ArrayList<>(); for (SDVariable v : variables()) { - if (v.getVarName().startsWith(scope.getName())) + if (v.name().startsWith(scope.getName())) vars.add(v); } return vars; @@ -831,99 +688,6 @@ public class SameDiff extends SDBaseOps { } - /** - * Get the shape for the given vertex id. - * Note that if an array is defined, it will use the shape of the array instead. - *

- * A shape *and* an array should not be defined at the same time. - * This wastes memory. The internal map used for tracking shapes for particular - * vertex ids should also delete redundant shapes stored to avoid redundant sources of information. - * - * @param varName the vertex id to get the shape for - * @return the shape for the given vertex if any. - */ - public long[] getShapeForVarName(String varName) { - if (arrayAlreadyExistsForVarName(varName)) { - return getVariable(varName).getArr().shape(); - } - return variableNameToShape.get(varName); - } - - /** - * See {@link #getShapeForVarName(String)}, but returns the shape descriptor. - */ - public LongShapeDescriptor getShapeDescriptorForVarName(String varName) { - if (getVariable(varName).getArr() != null) { - return getVariable(varName).getArr().shapeDescriptor(); - } - // FIXME: do we really want this Nd4j.dataType() here? - return LongShapeDescriptor.fromShape(variableNameToShape.get(varName), Nd4j.dataType()); - } - - - /** - * Associate a vertex id with the given shape. - * - * @param varName the vertex id to associate - * @param shape the shape to associate with - * @see #putShapeForVarName(String, long[]) - * @see #putOrUpdateShapeForVarName(String, long[], boolean) - */ - @Deprecated - public void putShapeForVarName(String varName, long[] shape) { - if (shape == null) { - throw new ND4JIllegalStateException("Shape must not be null!"); - } - - if (variableNameToShape.containsKey(varName)) { - throw new ND4JIllegalStateException("Shape for " + varName + " already exists!"); - } - - variableNameToShape.put(varName, shape); - } - - - /** - * Sets the shape descriptor for a variable. - */ - public void putShapeForVarName(String varName, LongShapeDescriptor shape) { - val v = getVariable(varName); - putShapeForVarName(varName, shape.getShape()); - v.setDataType(shape.dataType()); - } - - /** - * Put or update the shape for the given variable name. Optionally supports clearing the specified variable's - * INDArray if it's shape does not match the new shape - * - * @param varName Variable name - * @param shape Shape to put - * @param clearArrayOnShapeMismatch If false: no change to arrays. If true: if an INDArray is defined for the specified - * variable name, it will be removed from the graph (to be later re-generated) if - * its shape does not match the specified shape - */ - @Deprecated - public void putOrUpdateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch) { - Preconditions.checkNotNull(shape, "Cannot put null shape for variable: %s", varName); - if (variableNameToShape.containsKey(varName)) { -// updateShapeForVarName(varName, shape, clearArrayOnShapeMismatch); - //TODO - } else { - putShapeForVarName(varName, shape); - } - } - - /** - * Returns true if the given vertex id and shape already exist. - * - * @param varName the vertex id - * @return true if the ndarray and vertex id already exist - */ - public boolean shapeAlreadyExistsForVarName(String varName) { - return variableNameToShape.containsKey(varName) || arrayAlreadyExistsForVarName(varName); - } - - /** * Returns true if the given vertex id and {@link INDArray} already exist. * @@ -959,11 +723,6 @@ public class SameDiff extends SDBaseOps { SDVariable v = variables.get(varName).getVariable(); switch (v.getVariableType()) { case VARIABLE: - if (!variablesArrays.containsKey(varName)) { - //VARIBALE type arrays should have a parameter initializer... - // we should use this to azy init the array if none is present - v.storeAndAllocateNewArray(); - } return variablesArrays.get(varName).get(); case CONSTANT: if (!constantArrays.containsKey(varName)) @@ -1015,7 +774,7 @@ public class SameDiff extends SDBaseOps { arr = arr.castTo(variable.dataType()); Preconditions.checkState(variable.dataType() == arr.dataType(), "Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable", - variable.getVarName(), variable.dataType(), arr.dataType()); + variable.name(), variable.dataType(), arr.dataType()); if (sessions.get(Thread.currentThread().getId()) == null) { sessions.put(Thread.currentThread().getId(), new InferenceSession(this)); @@ -1042,10 +801,10 @@ public class SameDiff extends SDBaseOps { switch (variable.getVariableType()) { case VARIABLE: - variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads + variablesArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads break; case CONSTANT: - constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true)); + constantArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); break; case ARRAY: throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" + @@ -1062,19 +821,19 @@ public class SameDiff extends SDBaseOps { if (!placeholdersPerThread.containsKey(tid)) { placeholdersPerThread.put(tid, new HashMap()); } - placeholdersPerThread.get(tid).put(variable.getVarName(), arr); + placeholdersPerThread.get(tid).put(variable.name(), arr); break; default: throw new IllegalStateException("Unknown variable type: " + variable.getVariableType()); } - //putOrUpdateShapeForVarName(variable.getVarName(), arr.shape(), true); + //putOrUpdateShapeForVarName(variable.name(), arr.shape(), true); //Also update nested SameDiff instances (such as gradient function) if (sameDiffFunctionInstances != null && sameDiffFunctionInstances.size() > 0) { for (Map.Entry e : sameDiffFunctionInstances.entrySet()) { SameDiff sd = e.getValue(); - SDVariable v = sd.getVariable(variable.getVarName()); + SDVariable v = sd.getVariable(variable.name()); if (v != null) { sd.associateArrayWithVariable(arr, v); } @@ -1092,16 +851,16 @@ public class SameDiff extends SDBaseOps { */ public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable){ Preconditions.checkState(variable.getVariableType() == VariableType.VARIABLE || variable.getVariableType() == VariableType.CONSTANT, - "assignArray method can only be used with VARIBLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.getVarName(), variable.getVariableType()); + "assignArray method can only be used with VARIBLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.name(), variable.getVariableType()); //DeviceLocal doesn't work with views if(arr.isView()) arr = arr.dup(); if(variable.getVariableType() == VariableType.VARIABLE ){ - variablesArrays.get(variable.getVarName()).update(arr); + variablesArrays.get(variable.name()).update(arr); } else { - constantArrays.get(variable.getVarName()).update(arr); + constantArrays.get(variable.name()).update(arr); } } @@ -1134,38 +893,6 @@ public class SameDiff extends SDBaseOps { return ret; } - - /** - * Invoke an op by opName - * - * @param op the op - * @param x the first input - * @param y the second input - * @return the result variable - */ - @Deprecated //TO BE REMOVED - should not be part of public API - public SDVariable invoke(Op op, SDVariable x, SDVariable y) { - if (!opMethods.containsKey(op.opName())) { - throw new ND4JIllegalStateException("Illegal method opName " + op.opName()); - } - - if (x != null && y != null) { - try { - return (SDVariable) opMethods.get(op.opName()).invoke(this, x, y); - } catch (Exception e) { - - } - } else { - try { - return (SDVariable) opMethods.get(op.opName()).invoke(this, x); - } catch (Exception e) { - - } - } - - throw new ND4JIllegalStateException("Illegal method opName " + op.opName()); - } - /** * The set of defined SameDiff function names. SameDiff function instances should not be confused * with DifferentialFunction ops; an example of a SameDiff function instance is the gradient "grad" function @@ -1176,155 +903,10 @@ public class SameDiff extends SDBaseOps { return this.sameDiffFunctionInstances.keySet(); } - /** - * Invoke an op by opName - * - * @param op the op - * @param x the first input - * @return the result variable - */ - public SDVariable invoke(Op op, SDVariable x) { - return invoke(op, x, null); - } - private SameDiff() { functionFactory = new DifferentialFunctionFactory(this); - sameDiffFunctionDefinitionMap = new LinkedHashMap<>(); sameDiffFunctionInstances = new LinkedHashMap<>(); - forwardVarForGrad = new LinkedHashMap<>(); - opsForResult = new IntArrayKeyMap<>(); - variableNameToShape = new LinkedHashMap<>(); - placeHolderOriginalShapes = new LinkedHashMap<>(); - placeHolderFunctions = new LinkedHashSet<>(); - baseNameForFunctionInstanceId = new LinkedHashMap<>(); - propertiesToResolve = new LinkedHashMap<>(); - propertiesForFunction = new LinkedHashMap<>(); fieldVariableResolutionMapping = HashBasedTable.create(); - - } - - /** - * Adds a property that needs to be resolve for later. - * These variables are typically values that are arrays - * that are named but have an unknown value till execution time. - *

- * This is very common for model import. - * - * @param forFunction the function to add the property to resolve for - * @param arrayName the array name - */ - public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName) { - if (!propertiesToResolve.containsKey(forFunction.getOwnName())) { - List newVal = new ArrayList<>(); - newVal.add(arrayName); - propertiesToResolve.put(forFunction.getOwnName(), newVal); - } else { - List newVal = propertiesToResolve.get(forFunction.getOwnName()); - newVal.add(arrayName); - } - } - - /** - * Remove a property to resolve added with {@link #addPropertyToResolve(DifferentialFunction, String)} - * - * @param forFunction the function to add the property to resolve for - * @param arrayName the array name - */ - public void removePropertyToResolve(DifferentialFunction forFunction, String arrayName) { - if (propertiesToResolve.containsKey(forFunction.getOwnName())) { - List newVal = propertiesToResolve.get(forFunction.getOwnName()); - newVal.remove(arrayName); - } - } - - /** - * Return the properties to resolve for the given function. - * This is typically used right before execution in model import in - * {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()} - * - * @param function the function get the properties to resolve for - * @return the properties to resolve for the given function - */ - public List propertiesToResolveForFunction(DifferentialFunction function) { - if (!propertiesToResolve.containsKey(function.getOwnName())) - return Collections.emptyList(); - - return propertiesToResolve.get(function.getOwnName()); - } - - - private void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, Object propertyValue) { - if (!propertiesForFunction.containsKey(functionFor.getOwnName())) { - Map fields = new LinkedHashMap<>(); - fields.put(propertyName, propertyValue); - propertiesForFunction.put(functionFor.getOwnName(), fields); - } else { - val fieldMap = propertiesForFunction.get(functionFor.getOwnName()); - if (fieldMap.containsKey(propertyName)) { - throw new ND4JIllegalStateException("Attempting to override property " + propertyName); - } - - fieldMap.put(propertyName, propertyValue); - } - } - - - /** - * Adds a field name -> variable name mapping for a given function.
- * This is used for model import where there is an unresolved variable at the time of calling any - * {@link org.nd4j.imports.graphmapper.GraphMapper#importGraph(File)} - * . - *

- * This data structure is typically accessed during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()} - *

- * When a function attempts to resolve variables right before execution, there needs to be a way of knowing - * which variable in a samediff graph should map to a function's particular field name - * - * @param function the function to map - * @param fieldName the field name for the function to map - * @param varName the variable name of the array to get from samediff - */ - public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName) { - fieldVariableResolutionMapping.put(function.getOwnName(), fieldName, varName); - } - - /** - * Get the variable name to use - * for resolving a given field - * for a given function during import time. - * This method is u sed during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()} - * - * @param function the function to get the variable name for - * @param fieldName the field name to resolve for - * @return the resolve variable name if any - */ - public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName) { - return fieldVariableResolutionMapping.get(function.getOwnName(), fieldName); - } - - /** - * Sets a base name for the function id. - * This is used for when calling {@link #generateOutputVariableForOp(DifferentialFunction, String)} - * for ensuring original names for model import map to current samediff names - * when names are generated. - * - * @param baseName the base name to add - * @param function the function to declare a base name for. - */ - public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function) { - baseNameForFunctionInstanceId.put(function.getOwnName(), baseName); - } - - /** - * Returns the base name for the given function - * if any (may return null) - * - * @param function the function to get the base name for - * @return the base name for the given function (if any) based - * on the function's instance id. - */ - public String getBaseNameForFunction(DifferentialFunction function) { - return baseNameForFunctionInstanceId.get(function.getOwnName()); } @@ -1360,7 +942,7 @@ public class SameDiff extends SDBaseOps { public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) { String[] varNames = new String[variables.length]; for (int i = 0; i < varNames.length; i++) { - varNames[i] = variables[i].getVarName(); + varNames[i] = variables[i].name(); } addOutgoingFor(varNames, function); @@ -1499,7 +1081,7 @@ public class SameDiff extends SDBaseOps { if (interceptor != null) { pauseArgumentInterceptor(interceptor); for (int i = 0; i < variables.length; i++) { - variables[i] = interceptor.intercept(getVariable(variables[i])).getVarName(); + variables[i] = interceptor.intercept(getVariable(variables[i])).name(); } unpauseArgumentInterceptor(interceptor); } @@ -1507,13 +1089,6 @@ public class SameDiff extends SDBaseOps { if (function.getOwnName() == null) throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly"); - //double check if function contains placeholder args - for (val varName : variables) { - if (isPlaceHolder(varName)) { - placeHolderFunctions.add(function.getOwnName()); - } - } - //Add function if it doesn't exist //TODO could "not existing" be a bug sometimes? if (!ops.containsKey(function.getOwnName())) { @@ -1546,7 +1121,7 @@ public class SameDiff extends SDBaseOps { for (int i = 0; i < varNames.length; i++) { if (variables[i] == null) throw new ND4JIllegalStateException("Found null variable at index " + i); - varNames[i] = variables[i].getVarName(); + varNames[i] = variables[i].name(); } addArgsFor(varNames, function); } @@ -1561,25 +1136,8 @@ public class SameDiff extends SDBaseOps { function.getOwnName() + " only has " + function.args().length + " args but you are trying" + "to replace the argument at " + i); - String oldName = function.arg(i).getVarName(); - String newName = newArg.getVarName(); - - if (function.arg(i).isPlaceHolder() && !newArg.isPlaceHolder()) { - boolean otherPlaceholders = false; - for (int j = 0; j < function.argNames().length; j++) { - if (j == i) - continue; - - if (function.arg(j).isPlaceHolder()) - otherPlaceholders = true; - } - - if (!otherPlaceholders) - placeHolderFunctions.remove(function.getOwnName()); - } else if (!function.arg(i).isPlaceHolder() && newArg.isPlaceHolder()) { - if (!placeHolderFunctions.contains(function.getOwnName())) - placeHolderFunctions.add(function.getOwnName()); - } + String oldName = function.arg(i).name(); + String newName = newArg.name(); List oldArgs = ops.get(function.getOwnName()).getInputsToOp(); oldArgs = new ArrayList<>(oldArgs); @@ -1703,8 +1261,6 @@ public class SameDiff extends SDBaseOps { if (variables != null ? !variables.equals(sameDiff.variables) : sameDiff.variables != null) return false; - if (sameDiffFunctionDefinitionMap != null ? !sameDiffFunctionDefinitionMap.equals(sameDiff.sameDiffFunctionDefinitionMap) : sameDiff.sameDiffFunctionDefinitionMap != null) - return false; return sameDiffFunctionInstances != null ? sameDiffFunctionInstances.equals(sameDiff.sameDiffFunctionInstances) : sameDiff.sameDiffFunctionInstances == null; } @@ -1764,42 +1320,37 @@ public class SameDiff extends SDBaseOps { } /** - * Outputs are those variables (not placeholders, constants, etc) that are the output of a function that aren't the - * input to any other ops. - * Usually these are the output of the last function(s) in the SameDiff instance. + * Outputs are the names of the predictions of the network. + * Note that the outputs must be set using {@link #setOutputs(List)} first * - * @return The (inferred) outputs of the SameDiff instance, in no particular order + * @return The outputs of the SameDiff instance, or null if no outputs have been set */ public List outputs() { - List out = new ArrayList<>(); - for (Variable v : variables.values()) { - if (v.getVariable().isConstant() || v.getVariable().isPlaceHolder() || //Exclude constants and placeholders - (v.getInputsForOp() != null && !v.getInputsForOp().isEmpty()) || //Exclude variables that are inputs to ops - (v.getControlDepsForOp() != null && !v.getControlDepsForOp().isEmpty()) || //Exclude variables that are control dependency inputs to ops - (v.getControlDepsForVar() != null && !v.getControlDepsForVar().isEmpty())) { //Exclude variables that are control dependency inputs to other variables (mainly for import of cond etc ops) - continue; + return this.outputs; + } + + /** + * See {@link #setOutputs(List)} + */ + public void setOutputs(String... outputs){ + setOutputs(outputs == null ? null : Arrays.asList(outputs)); + } + + + /** + * Set the outputs of the SameDiff instance. + * Outputs are the names of the variables that are the predictions of the neural network. + * Note that this is merely a convenience, and does not impact execution at all. Outputs can be retrieved (after + * setting here) using {@link #outputs()} + * @param outputs Outputs to set. Must be valid variable names in this SameDiff instance + */ + public void setOutputs(List outputs){ + if(outputs != null){ + for(String s : outputs){ + Preconditions.checkArgument(variables.containsKey(s), "Cannot set variable \"%s\" as an output: SameDiff instance does not contain a variable with this name"); } - - //Also exclude assert etc ops - doesn't make sense to return these "outputs" to user - if (v.getOutputOfOp() != null) { - String opName = v.getOutputOfOp(); - SameDiffOp o = ops.get(opName); - if (o.getOp() instanceof Assert) { - continue; - } - - //A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed - // by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here - // This applies to SameDiff while loops as well - if (o.getOp() instanceof Switch) { - continue; - } - } - - - out.add(v.getName()); } - return out; + this.outputs = outputs; } /** @@ -1845,7 +1396,7 @@ public class SameDiff extends SDBaseOps { public void setLossVariables(@NonNull SDVariable... lossVariables) { String[] varNames = new String[lossVariables.length]; for (int i = 0; i < lossVariables.length; i++) - varNames[i] = lossVariables[i].getVarName(); + varNames[i] = lossVariables[i].name(); setLossVariables(varNames); } @@ -1874,7 +1425,7 @@ public class SameDiff extends SDBaseOps { * See {@link #addLossVariable(String)} */ public void addLossVariable(@NonNull SDVariable variable) { - addLossVariable(variable.getVarName()); + addLossVariable(variable.name()); } /** @@ -2161,7 +1712,6 @@ public class SameDiff extends SDBaseOps { Map placeholders = toPlaceholderMap(ds); Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training"); - resolveVariablesWith(placeholders); //Call TrainingSession to perform training if (!initializedTraining) @@ -2825,7 +2375,7 @@ public class SameDiff extends SDBaseOps { * INDArray out = sd.output() * .data(data) * .output("pred") - * .execSingle(); + * .outputSingle(); * } * */ @@ -2861,7 +2411,7 @@ public class SameDiff extends SDBaseOps { if (outputs != null && outputs.length != 0) { neededOutputs = Arrays.asList(outputs); } else { - neededOutputs = outputs(); + neededOutputs = getLossVariables(); } String[] neededOutputsArr = neededOutputs.toArray(new String[0]); @@ -2931,7 +2481,7 @@ public class SameDiff extends SDBaseOps { * .output("out") * .input("x", xValue) * .input(y, yValue) - * .execSingle(); + * .outputSingle(); * } * */ @@ -2939,14 +2489,6 @@ public class SameDiff extends SDBaseOps { return new BatchOutputConfig(this); } - /** - * @deprecated See {@link #outputAll(Map)} and {@link #batchOutput()} - */ - @Deprecated - public Map execAll(Map placeholders) { - return outputAll(placeholders); - } - /** * Do inference for all variables for a single batch. *

@@ -2957,15 +2499,6 @@ public class SameDiff extends SDBaseOps { public Map outputAll(Map placeholders) { return batchOutput().outputAll().inputs(placeholders).exec(); } - - /** - * @deprecated See {@link #outputSingle(Map, String)} and {@link #batchOutput()} - */ - @Deprecated - public INDArray execSingle(Map placeholders, String output) { - return outputSingle(placeholders, output); - } - /** * Do inference for a single variable for a single batch. *

@@ -2977,14 +2510,6 @@ public class SameDiff extends SDBaseOps { return batchOutput().output(output).inputs(placeholders).execSingle(); } - /** - * @deprecated See {@link #output(Map, List)} and {@link #batchOutput()} - */ - @Deprecated - public Map exec(Map placeholders, List outputs) { - return output(placeholders, outputs); - } - /** * Do inference for the given variables for a single batch. *

@@ -2992,18 +2517,10 @@ public class SameDiff extends SDBaseOps { *

* Special case of {@link #batchOutput()}. */ - public Map output(Map placeholders, List outputs) { + public Map output(Map placeholders, @NonNull List outputs) { return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).exec(); } - /** - * @deprecated See {@link #output(Map, String...)} and {@link #batchOutput()} - */ - @Deprecated - public Map exec(Map placeholders, String... outputs) { - return output(placeholders, outputs); - } - /** * Do inference for the given variables for a single batch. *

@@ -3084,7 +2601,7 @@ public class SameDiff extends SDBaseOps { /** * See {@link #one(String, DataType, int...)}. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable one(String name, int... shape) { @@ -3093,7 +2610,7 @@ public class SameDiff extends SDBaseOps { /** * See {@link #one(String, DataType, long...)}. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable one(String name, long... shape) { @@ -3103,31 +2620,31 @@ public class SameDiff extends SDBaseOps { /** * Create a new variable with the specified shape, with all values initialized to 1.0. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable one(String name, org.nd4j.linalg.api.buffer.DataType dataType, int... shape) { - return var(name, new ConstantInitScheme('f', 1.0), dataType, ArrayUtil.toLongArray(shape)); + return one(name, dataType, ArrayUtil.toLongArray(shape)); } /** * Create a new variable with the specified shape, with all values initialized to 1.0. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable one(String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { - return var(name, new ConstantInitScheme('f', 1.0), dataType, shape); + return constant(name, Nd4j.ones(dataType, shape)); } /** * See {@link #zero(String, DataType, long...)}. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable zero(String name, long... shape) { @@ -3136,7 +2653,7 @@ public class SameDiff extends SDBaseOps { /** * See {@link #zero(String, DataType, int...)}. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable zero(String name, int... shape) { @@ -3145,26 +2662,26 @@ public class SameDiff extends SDBaseOps { /** * Create a new variable with the specified shape, with all values initialized to 0. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable zero(String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { - return var(name, new ZeroInitScheme(), dataType, shape); + return constant(name, Nd4j.zeros(dataType, shape)); } /** * Create a new variable with the specified shape, with all values initialized to 0. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable zero(String name, org.nd4j.linalg.api.buffer.DataType dataType, int... shape) { - return var(name, new ZeroInitScheme(), dataType, ArrayUtil.toLongArray(shape)); + return zero(name, dataType, ArrayUtil.toLongArray(shape)); } /** @@ -3196,39 +2713,13 @@ public class SameDiff extends SDBaseOps { } } - SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null); - name = v.getVarName(); + SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType()); + name = v.name(); variables.put(name, Variable.builder().name(name).variable(v).build()); constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads return v; } - /** - * Return a variable of given shape in which all values have a given constant value. - * - * @param value constant to set for each value - * @param shape shape of the variable as long array - * @return A new SDVariable of provided shape with constant value. - */ - @Deprecated - public SDVariable constant(SDVariable value, long... shape) { - return constant(null, value, shape); - } - - /** - * Return a variable of given shape in which all values have a given constant value. - * - * @param name Name of the new SDVariable - * @param value constant to set for each value - * @param shape shape of the variable as long array - * @return A new SDVariable of provided shape with constant value. - */ - @Deprecated - public SDVariable constant(String name, SDVariable value, long... shape) { - SDVariable ret = f().constant(value, shape); - return updateVariableNameAndReference(ret, name); - } - /** * Create a a placeholder variable. Placeholders are variables that expect an array to be provided during training * and inference.
@@ -3242,7 +2733,7 @@ public class SameDiff extends SDBaseOps { */ public SDVariable placeHolder(@NonNull String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { Preconditions.checkState(!variables.containsKey(name), "Variable already exists with name %s", name); - SDVariable ret = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType, null); + SDVariable ret = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType); variables.put(name, Variable.builder().name(name).variable(ret).build()); return ret; } @@ -3260,8 +2751,6 @@ public class SameDiff extends SDBaseOps { return var(name, VariableType.VARIABLE, weightInitScheme, dataType, shape); } - //TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API! - /** * Variable initialization with a specified {@link WeightInitScheme} * This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details. @@ -3295,14 +2784,19 @@ public class SameDiff extends SDBaseOps { } } + Preconditions.checkState(variableType != VariableType.VARIABLE || weightInitScheme != null, "A weight initalization scheme must be provided" + + " when creating a VARIABLE type SDVariables - variable name: \"%s\"", name); - SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme); + SDVariable ret = new SDVariable(name, variableType, this, shape, dataType); addVariable(ret); - if (variableType == VariableType.PLACEHOLDER) { - setOriginalPlaceHolderShape(name, shape); - putShapeForVarName(name, shape); + if(variableType == VariableType.VARIABLE){ + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + INDArray vArr = weightInitScheme.create(dataType, shape); + variablesArrays.put(name, new DeviceLocalNDArray(vArr, true)); + } } + return ret; } @@ -3418,25 +2912,29 @@ public class SameDiff extends SDBaseOps { * @return */ public SDVariable var(@NonNull final SDVariable v) { - if (variables.containsKey(v.getVarName()) && variables.get(v.getVarName()).getVariable().getArr() != null) - return variables.get(v.getVarName()).getVariable(); + if (variables.containsKey(v.name()) && variables.get(v.name()).getVariable().getArr() != null) + return variables.get(v.name()).getVariable(); - if (v.getVarName() == null || v.getVarName().length() < 1) + if (v.name() == null || v.name().length() < 1) throw new IllegalArgumentException("Name for variable must be defined"); VariableType vt = v.getVariableType(); NDArraySupplierInitScheme s = null; switch (vt) { case VARIABLE: - s = new NDArraySupplierInitScheme(v.getArr()); - //Intentional fallthrough + SDVariable r = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType()); + addVariable(r); + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ + variablesArrays.put(v.name(), new DeviceLocalNDArray(v.getArr().dup(), true)); + } + return r; case ARRAY: - SDVariable ret = new SDVariable(v.getVarName(), v.getVariableType(), this, v.getShape(), v.dataType(), s); + SDVariable ret = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType()); return addVariable(ret); case CONSTANT: - return constant(v.getVarName(), v.getArr()); + return constant(v.name(), v.getArr()); case PLACEHOLDER: - return placeHolder(v.getVarName(), v.dataType(), v.placeholderShape()); + return placeHolder(v.name(), v.dataType(), v.placeholderShape()); default: throw new RuntimeException("Unknown/not supported variable type: " + vt); } @@ -3531,12 +3029,10 @@ public class SameDiff extends SDBaseOps { } } - SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType(), new NDArraySupplierInitScheme(arr)); + SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType()); associateArrayWithVariable(arr, ret); addVariable(ret); - if (getShapeForVarName(name) == null) - putShapeForVarName(name, arr.shape()); return ret; } @@ -3586,7 +3082,7 @@ public class SameDiff extends SDBaseOps { sameDiffFunctionInstances.remove(GRAD_FN_KEY); for (SDVariable variable : variables) { - String n = variable.getVarName(); + String n = variable.name(); INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); @@ -3605,7 +3101,7 @@ public class SameDiff extends SDBaseOps { if (trainingConfig != null && initializedTraining) { //Remove updater state for now constant variables for (SDVariable v : variables) { - GradientUpdater gu = updaterMap.remove(v.getVarName()); + GradientUpdater gu = updaterMap.remove(v.name()); Map m = gu == null ? null : gu.getState(); if (m != null) { for (INDArray arr : m.values()) { @@ -3615,27 +3111,27 @@ public class SameDiff extends SDBaseOps { } //Also check dataset feature/label mapping - remove any placeholders here... - if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(v.getVarName())) { + if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(v.name())) { List newFM = new ArrayList<>(trainingConfig.getDataSetFeatureMapping()); //New list in case of immutable list - newFM.remove(v.getVarName()); + newFM.remove(v.name()); trainingConfig.setDataSetFeatureMapping(newFM); } - if (trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(v.getVarName())) { + if (trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(v.name())) { List newLM = new ArrayList<>(trainingConfig.getDataSetLabelMapping()); - newLM.remove(v.getVarName()); + newLM.remove(v.name()); trainingConfig.setDataSetLabelMapping(newLM); } - if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(v.getVarName())) { + if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(v.name())) { List newFMM = new ArrayList<>(trainingConfig.getDataSetFeatureMaskMapping()); - newFMM.remove(v.getVarName()); + newFMM.remove(v.name()); trainingConfig.setDataSetFeatureMaskMapping(newFMM); } - if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(v.getVarName())) { + if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(v.name())) { List newLMM = new ArrayList<>(trainingConfig.getDataSetLabelMaskMapping()); - newLMM.remove(v.getVarName()); + newLMM.remove(v.name()); trainingConfig.setDataSetLabelMaskMapping(newLMM); } } @@ -3652,7 +3148,7 @@ public class SameDiff extends SDBaseOps { */ public SDVariable convertToVariable(@NonNull SDVariable constant) { Preconditions.checkState(constant.dataType().isFPType(), "Only floating point SDVariables can be converted to variables," + - " datatype of %s is %s", constant.getVarName(), constant.dataType()); + " datatype of %s is %s", constant.name(), constant.dataType()); convertToVariables(Collections.singletonList(constant)); return constant; } @@ -3684,7 +3180,7 @@ public class SameDiff extends SDBaseOps { sameDiffFunctionInstances.remove(GRAD_FN_KEY); for (SDVariable variable : constants) { - String n = variable.getVarName(); + String n = variable.name(); INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); @@ -3704,7 +3200,7 @@ public class SameDiff extends SDBaseOps { if (trainingConfig != null && initializedTraining) { //Add updater state for this variable: updaterState, updaterViews, updaterMap for (SDVariable v : constants) { - if (!updaterMap.containsKey(v.getVarName())) { + if (!updaterMap.containsKey(v.name())) { //Create new updater state INDArray arr = v.getArr(); long thisSize = trainingConfig.getUpdater().stateSize(arr.length()); @@ -3712,10 +3208,10 @@ public class SameDiff extends SDBaseOps { INDArray stateArr = Nd4j.create(arr.dataType(), 1, thisSize); GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, false); u.setStateViewArray(stateArr, arr.shape(), arr.ordering(), true); //TODO eventually this should be 1 call... - updaterMap.put(v.getVarName(), u); + updaterMap.put(v.name(), u); } else { GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray) null, true); - updaterMap.put(v.getVarName(), u); + updaterMap.put(v.name(), u); } } } @@ -3914,6 +3410,26 @@ public class SameDiff extends SDBaseOps { variables.remove(from); variables.put(to, v); + if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.containsKey(from)){ + DeviceLocalNDArray dl = constantArrays.remove(from); + constantArrays.put(to, dl); + } + + if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.containsKey(from)){ + DeviceLocalNDArray dl = variablesArrays.remove(from); + variablesArrays.put(to, dl); + } + + if(v.getVariable().getVariableType() == VariableType.PLACEHOLDER ){ + for(Map e : placeholdersPerThread.values()){ + //Not really thread safe - but renaming variables during execution in other threads can never be thread safe :) + if(e != null && e.containsKey(from)){ + INDArray arr = e.remove(from); + e.put(to, arr); + } + } + } + if (trainingConfig != null) { if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(from)) { List l = new ArrayList<>(trainingConfig.getDataSetFeatureMapping()); @@ -3974,7 +3490,7 @@ public class SameDiff extends SDBaseOps { val args = function.args(); for (int i = 0; i < args.length; i++) { - if (args[i].getVarName().equals(varName)) { + if (args[i].name().equals(varName)) { /** * Since we are removing the variable reference * from the arguments we need to update both @@ -4075,15 +3591,6 @@ public class SameDiff extends SDBaseOps { variables.get(variableName).setGradient(variable); } - - /** - * @param varName - * @param forwardVariable - */ - public void setForwardVariableForVarName(String varName, SDVariable forwardVariable) { - forwardVarForGrad.put(varName, forwardVariable); - } - /** * Get the gradient for the variable with the specified variable name. * Note that in order to run this function, {@link #execBackwards(Map, Operation, MultiDataSet, Collection, List)} must be executed first. @@ -4094,12 +3601,12 @@ public class SameDiff extends SDBaseOps { */ public SDVariable grad(String varName) { if (!sameDiffFunctionInstances.containsKey(GRAD_FN_KEY)) { - throw new IllegalStateException("Unable to obtain gradient. Please run execBackwards() first."); + createGradFunction(); } SameDiff grad = getFunction(GRAD_FN_KEY); SDVariable var = grad.getVariable(varName); - return getFunction(GRAD_FN_KEY).getGradForVariable(var.getVarName()); + return getFunction(GRAD_FN_KEY).getGradForVariable(var.name()); } @@ -4281,12 +3788,12 @@ public class SameDiff extends SDBaseOps { public SDVariable addVariable(SDVariable variable) { Preconditions.checkState(variable.getSameDiff() == this, "Samediff instance must be the same."); - if (variables.containsKey(variable.getVarName()) && !variables.get(variable.getVarName()).getVariable().equals(variable)) { - throw new IllegalArgumentException("Variable with name \"" + variable.getVarName() + "\" already exists"); + if (variables.containsKey(variable.name()) && !variables.get(variable.name()).getVariable().equals(variable)) { + throw new IllegalArgumentException("Variable with name \"" + variable.name() + "\" already exists"); } Preconditions.checkState(variable.getSameDiff() == this, "Same diff instance for variable must be the same!"); - variables.put(variable.getVarName(), Variable.builder().name(variable.getVarName()).variable(variable).build()); + variables.put(variable.name(), Variable.builder().name(variable.name()).variable(variable).build()); return variable; } @@ -4299,11 +3806,6 @@ public class SameDiff extends SDBaseOps { * @return the set of names generated for each output of the function. */ public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport) { - //xyz ops only have 1 output - //if there is already a base name defined, use that - if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null) - baseName = getBaseNameForFunction(function); - if (baseName == null) baseName = function.getOwnName(); @@ -4493,15 +3995,6 @@ public class SameDiff extends SDBaseOps { } - @Deprecated - public INDArray execAndEndResult() { - List outputs = outputs(); - Preconditions.checkState(outputs.size() == 1, "Method can only be used with SameDiff instances with a single output"); - long tid = Thread.currentThread().getId(); - Map placeholders = placeholdersPerThread.get(tid); - return execSingle(placeholders, outputs.get(0)); - } - /** * See {@link #calculateGradients(Map, Collection)} */ @@ -4529,7 +4022,7 @@ public class SameDiff extends SDBaseOps { SDVariable v = getVariable(s).getGradient(); if (v != null) { //In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable - gradVarNames.add(v.getVarName()); + gradVarNames.add(v.name()); } } @@ -4539,7 +4032,7 @@ public class SameDiff extends SDBaseOps { Map out = new HashMap<>(); for (String s : variables) { if (getVariable(s).getGradient() != null) { - String gradVar = getVariable(s).getGradient().getVarName(); + String gradVar = getVariable(s).getGradient().name(); out.put(s, grads.get(gradVar)); } } @@ -4547,136 +4040,6 @@ public class SameDiff extends SDBaseOps { return out; } - /** - * Create (if required) and then calculate the variable gradients (backward pass) for this graph.
- * After execution, the gradient arrays can be accessed using {@code myVariable.getGradient().getArr()}
- * Note: This method by default calculates VARIABLE type SDVariable gradients only (as well as any other - * gradients needed to calculate the variable gradients). That is, placeholder, constant, etc gradients are not - * calculated. If these gradients are required, they can be calculated using {@link #execBackwards(Map, List, Operation, MultiDataSet, Collection, List)} instead, - * which allows specifying the set of SDVariables to calculate the gradients for. For example, - * {@code execBackwards(placeholders, Arrays.asList(myPlaceholder.gradient().getVarName())}. In some cases, - * {@link #createGradFunction()} may need to be called first - * - * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map - */ - @Deprecated - public void execBackwards(Map placeholders, Operation op) { - execBackwards(placeholders, op, null, Collections.emptyList(), Collections.emptyList()); - } - - /** - * See {@link #execBackwards(Map, Operation)}. - *

- * Uses {@link Operation#INFERENCE}. - */ - @Deprecated - public void execBackwards(Map placeholders) { - execBackwards(placeholders, Operation.INFERENCE); - } - - @Deprecated - protected void execBackwards(Map placeholders, Operation op, MultiDataSet batch, Collection requiredActivations, List activeListeners) { - if (getFunction(GRAD_FN_KEY) == null) { - createGradFunction(); - } - - //Collect (unique) list of gradient names... - Set varGradNames = new HashSet<>(); - for (Variable v : variables.values()) { - if (v.getVariable().getVariableType() == VariableType.VARIABLE) { - SDVariable g = v.getVariable().gradient(); - if (g != null) { - //Not all variables can have gradients... for example: suppose graph has 2 independent loss functions, - // optimizing only 1 might not require changing all variables - varGradNames.add(g.getVarName()); - } - } - } - - //Also add loss values - we need these so we can report them to listeners or loss curves... - if (!activeListeners.isEmpty() || op == Operation.TRAINING) { - varGradNames.addAll(lossVariables); - } - - //Edge case: if no variables, no variable gradients to calculate... - if (varGradNames.isEmpty()) { - log.warn("Skipping gradient execution (backward pass) - no variables to be calculated (graph does not contain any VARIABLE type SDVariables).\n" + - "If gradients for other variables (such as placeholders) are required, use execBackwards(Map, List) instead"); - } - - List vargradNamesList = new ArrayList<>(varGradNames); - execBackwards(placeholders, vargradNamesList, op, batch, requiredActivations, activeListeners); - } - - /** - * See {@link #execBackwards(Map, List, Operation)} - */ - @Deprecated - public Map execBackwards(Map placeholders, Operation op, String... variableGradNamesList) { - return execBackwards(placeholders, Arrays.asList(variableGradNamesList), op, null, Collections.emptyList(), Collections.emptyList()); - } - - /** - * See {@link #execBackwards(Map, Operation, String...)}. - *

- * Uses {@link Operation#INFERENCE}. - */ - @Deprecated - public Map execBackwards(Map placeholders, String... variableGradNamesList) { - return execBackwards(placeholders, Operation.INFERENCE, variableGradNamesList); - } - - /** - * As per {@link #execBackwards(Map, Operation, MultiDataSet, Collection, List)}, but the set of gradients to calculate can be specified manually.
- * For example, to calculate the gradient for placeholder variable "myPlaceholder", use - * {@code execBackwards(placeholders, Arrays.asList(myPlaceholder.gradient().getVarName())}. - * - * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map - * @param variableGradNamesList Names of the gradient variables to calculate - */ - @Deprecated - public Map execBackwards(Map placeholders, List variableGradNamesList, Operation operation) { - return execBackwards(placeholders, variableGradNamesList, operation, null, Collections.emptyList(), Collections.emptyList()); - } - - /** - * See {@link #execBackwards(Map, List, Operation)}. - *

- * Uses {@link Operation#INFERENCE}. - */ - @Deprecated - public Map execBackwards(Map placeholders, List variableGradNamesList) { - return execBackwards(placeholders, variableGradNamesList, Operation.INFERENCE); - } - - @Deprecated - protected Map execBackwards(Map placeholders, List variableGradNamesList, Operation operation, - MultiDataSet batch, Collection requiredActivations, List activeListeners) { - if (getFunction(GRAD_FN_KEY) == null) { - createGradFunction(); - } - - log.trace("About to execute backward function"); - - //Edge case: if no variables, no variable gradients to calculate... - if (variableGradNamesList.isEmpty()) { - log.warn("Skipping gradient calculation (backward pass) - no variables to be calculated (variableGradNamesList is empty)"); - return Collections.emptyMap(); - } - - SameDiff sd = sameDiffFunctionInstances.get(GRAD_FN_KEY); - sd.listeners.clear(); - sd.listeners.addAll(activeListeners); - - At at = new At(0, 0, 0, Thread.currentThread().getId(), operation); - if (trainingConfig != null) { - at.setIteration(trainingConfig.getIterationCount()); - at.setEpoch(trainingConfig.getEpochCount()); - } - - return sd.directExecHelper(placeholders, at, batch, requiredActivations, activeListeners, variableGradNamesList.toArray(new String[0])); - } - /** * Returns true if the gradient function has been created - i.e., {@link #createGradFunction()} or {@link #createGradFunction(String...)} * has been called at all @@ -4688,7 +4051,7 @@ public class SameDiff extends SDBaseOps { } /** - * Create the gradient function (for calculating gradients via {@link #execBackwards(Map, Operation, String[])}) if it is not already defined. + * Create the gradient function (for calculating gradients via {@link #calculateGradients(Map, Collection)}) if it is not already defined. * Users do not usually need to call this function manually, as it is called as required in the aforementioned method. *

* If the gradient function already exists, this method is a no-op.
@@ -4715,14 +4078,23 @@ public class SameDiff extends SDBaseOps { if (trainingConfig != null && trainingConfig.getLossVariables() != null && !trainingConfig.getLossVariables().isEmpty()) { lossVariables.addAll(trainingConfig.getLossVariables()); } else { - List outputs = outputs(); - if (outputs.size() == 1) { - String outName = outputs.get(0); + List lossInferred = bestGuessLossVariables(); + if (lossInferred.size() == 1) { + String outName = lossInferred.get(0); String opName = variables.get(outName).getOutputOfOp(); if (opName == null || !(ops.get(opName).getOp() instanceof ExternalErrorsFunction)) { - log.info("Inferring output \"{}\" as loss variable as none were previously set. Use SameDiff.setLossVariables() to override", outputs.get(0)); + log.info("Inferring output \"{}\" as loss variable as none were previously set." + + "Use SameDiff.setLossVariables() or SDVariable.markAsLoss() to override", lossInferred.get(0)); + } + lossVariables.add(lossInferred.get(0)); + } else if(lossInferred.isEmpty()){ + //Check for external errors function + for(SameDiffOp o : ops.values()){ + if(o.getOp() instanceof ExternalErrorsFunction){ + List l = o.getOutputsOfOp(); + lossVariables.add(l.get(0)); + } } - lossVariables.add(outputs.get(0)); } } } @@ -4824,9 +4196,9 @@ public class SameDiff extends SDBaseOps { "point variable (datatype: %s). Only floating point variables may be used as loss function variable", s, v.dataType()); v = v.sum(); //If output is not a scalar: we'll use loss = v.sum(), same as adding loss for multiple outputs. We don't always know for sure if output is scalar at this point if (v.dataType() == initialGrad.dataType()) { - sameDiff.setGradientForVariableName(v.getVarName(), initialGrad); + sameDiff.setGradientForVariableName(v.name(), initialGrad); } else { - sameDiff.setGradientForVariableName(v.getVarName(), initialGrad.castTo(v.dataType())); + sameDiff.setGradientForVariableName(v.name(), initialGrad.castTo(v.dataType())); } if (finalOutputs.contains(v)) { log.warn("Loss function variable \"{}\" appears multiple times in list of loss variables - using only first instance", s); @@ -4954,7 +4326,7 @@ public class SameDiff extends SDBaseOps { //At this point: we know the set of variables that are connected to the loss - these all (and only) need gradients Queue availableForDiff = new LinkedList<>(); for (SDVariable lossVar : finalOutputs) { - Variable v = sameDiff.variables.get(lossVar.getVarName()); + Variable v = sameDiff.variables.get(lossVar.name()); if (v.getOutputOfOp() != null) { String opName = v.getOutputOfOp(); availableForDiff.add(opName); @@ -5136,52 +4508,39 @@ public class SameDiff extends SDBaseOps { associateSameDiffWithOpsAndVariables(); } - /** - * Set the original shape for a given place holder.
- * This is used to track original shapes of place holder variables.
- * The reason we track original shapes is to validate possible candidate arrays coming in (especially with -1 - * as the expected shapes). - *

- * Note that if {@link #isPlaceHolder(String)} - * returns false for the passed in vertex id, - * a {@link ND4JIllegalStateException} is thrown. - *

- * - * @param variableName the vertex id for the original shape - * @param shape the shape of the place holder + * Try to infer the loss variable/s (usually loss variables). Note that this is not reliable in general. */ - public void setOriginalPlaceHolderShape(String variableName, @NonNull long... shape) { - if (!isPlaceHolder(variableName)) { - throw new ND4JIllegalStateException("Vertex id " + variableName + " does not appear to be a place holder. Did you forget to call addPlaceHolder?"); + protected List bestGuessLossVariables() { + List out = new ArrayList<>(); + for (Variable v : variables.values()) { + if (v.getVariable().isConstant() || v.getVariable().isPlaceHolder() || //Exclude constants and placeholders + (v.getInputsForOp() != null && !v.getInputsForOp().isEmpty()) || //Exclude variables that are inputs to ops + (v.getControlDepsForOp() != null && !v.getControlDepsForOp().isEmpty()) || //Exclude variables that are control dependency inputs to ops + (v.getControlDepsForVar() != null && !v.getControlDepsForVar().isEmpty())) { //Exclude variables that are control dependency inputs to other variables (mainly for import of cond etc ops) + continue; + } + + //Also exclude assert etc ops - doesn't make sense to return these "outputs" to user + if (v.getOutputOfOp() != null) { + String opName = v.getOutputOfOp(); + SameDiffOp o = ops.get(opName); + if (o.getOp() instanceof Assert) { + continue; + } + + //A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed + // by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here + // This applies to SameDiff while loops as well + if (o.getOp() instanceof Switch) { + continue; + } + } + + + out.add(v.getName()); } - - if (shape == null) { - throw new ND4JIllegalStateException("Null and 0 length shape arrays not allowed"); - } - - - if (placeHolderOriginalShapes.containsKey(variableName) && !Arrays.equals(placeHolderOriginalShapes.get(variableName), shape)) { - throw new ND4JIllegalStateException("Unable to add a new shape for vertex id " + variableName); - } - - //after validation now only set once - placeHolderOriginalShapes.put(variableName, shape); - - } - - - /** - * Get the original shape for the vertex id if one was set (other wise returns null).
- * This is mainly for use in validating passed in arrays as arguments to {@link #resolveVariablesWith(Map)} - * usually when executing using {@link #execAll(Map)} - * - * @param varName the vertex id to get the original shape for. - * @return the set vertex - */ - @Deprecated - public long[] getOriginalShapeForPlaceHolder(String varName) { - return placeHolderOriginalShapes.get(varName); + return out; } /** @@ -5196,53 +4555,6 @@ public class SameDiff extends SDBaseOps { return variables.get(varName).getVariable().isPlaceHolder(); } - - /** - * Resolve all ndarrays by updating the variables for each array specified in the given map. - * An {@link IllegalStateException} will be thrown if not all arrays are specified for resolution. - * - * @param arrays the arrays to resolve. - */ - public void resolveVariablesWith(Map arrays) { - for (Map.Entry e : arrays.entrySet()) { - SDVariable varForName = getVariable(e.getKey()); - if (varForName == null) { - throw new ND4JIllegalStateException("A placeholder array was provided for variable with name \"" + e.getKey() + - "\" but no variable with this name exists"); - } - - Variable v = variables.get(e.getKey()); - if (varForName.getVariableType() == VariableType.PLACEHOLDER) { - //Check shape: - long[] shape = varForName.placeholderShape(); - long[] newShape = e.getValue().shape(); - Preconditions.checkState(shape.length == newShape.length, "Placeholder shape not compatible (mismatched rank): placeholder \"%s\" " + - "shape %s, got incompatible shape %s", e.getKey(), shape, newShape); - } - } - - - for (val entry : arrays.entrySet()) { - if (!variables.get(entry.getKey()).getVariable().isPlaceHolder()) { - throw new ND4JIllegalStateException("Illegal variable " + entry.getKey() + " passed in. Variable found not to be a place holder variable"); - } - - val specifiedShape = getOriginalShapeForPlaceHolder(entry.getKey()); - //whole shape was specified: validate whether the input array shape is equal - if (!Shape.isPlaceholderShape(specifiedShape)) { - if (!Shape.shapeEquals(specifiedShape, entry.getValue().shape())) { - throw new ND4JIllegalStateException("Place holder shape specified was " + Arrays.toString(specifiedShape) + " but array shape was " + Arrays.toString(entry.getValue().shape())); - } - } - - associateArrayWithVariable(entry.getValue(), getVariable(entry.getKey())); - setArrayForVariable(entry.getKey(), entry.getValue()); - } - - //declare resolved - resolvedVariables = true; - } - /** * Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable. *

@@ -5270,20 +4582,20 @@ public class SameDiff extends SDBaseOps { throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable"); } - if (newVarName == null && variables.containsKey(varToUpdate.getVarName()) - && variables.get(varToUpdate.getVarName()).getVariable() != varToUpdate) { + if (newVarName == null && variables.containsKey(varToUpdate.name()) + && variables.get(varToUpdate.name()).getVariable() != varToUpdate) { //Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name // "mean" and consequently a new variable name needs to be generated - newVarName = generateNewVarName(varToUpdate.getVarName(), 0); + newVarName = generateNewVarName(varToUpdate.name(), 0); } - if (newVarName == null || varToUpdate.getVarName().equals(newVarName)) { + if (newVarName == null || varToUpdate.name().equals(newVarName)) { return varToUpdate; } - val oldVarName = varToUpdate.getVarName(); + val oldVarName = varToUpdate.name(); varToUpdate.setVarName(newVarName); - updateVariableName(oldVarName, newVarName); + renameVariable(oldVarName, newVarName); return varToUpdate; } @@ -5446,11 +4758,11 @@ public class SameDiff extends SDBaseOps { List allVars = variables(); for (SDVariable variable : allVars) { INDArray arr = variable.getVariableType() == VariableType.ARRAY ? null : variable.getArr(); - log.trace("Exporting variable: [{}]", variable.getVarName()); + log.trace("Exporting variable: [{}]", variable.name()); //If variable is the output of some op - let's use the ONE index for exporting, and properly track the output // numbers. For example, unstack(x) -> y0, y1, y2 -> the y's should be say (3,0), (3,1), (3,2) NOT (4,0), (5,0), (6,0) - String varName = variable.getVarName(); + String varName = variable.name(); int varIdx; int outputNum; if (variables.get(varName).getOutputOfOp() != null) { @@ -5471,11 +4783,11 @@ public class SameDiff extends SDBaseOps { } - reverseMap.put(variable.getVarName(), varIdx); + reverseMap.put(variable.name(), varIdx); - log.trace("Adding [{}] as [{}]", variable.getVarName(), varIdx); + log.trace("Adding [{}] as [{}]", variable.name(), varIdx); int shape = 0; - int name = bufferBuilder.createString(variable.getVarName()); + int name = bufferBuilder.createString(variable.name()); int array = 0; int id = IntPair.createIntPair(bufferBuilder, varIdx, outputNum); byte varType = (byte) variable.getVariableType().ordinal(); @@ -5486,7 +4798,10 @@ public class SameDiff extends SDBaseOps { if (variable.getVariableType() == VariableType.PLACEHOLDER) { val shp = variable.getShape(); - shape = FlatVariable.createShapeVector(bufferBuilder, shp); + if(shp != null) { + //Some models may have no shape defined, not ever a placeholder type shape + shape = FlatVariable.createShapeVector(bufferBuilder, shp); + } } int controlDeps = 0; @@ -5536,7 +4851,7 @@ public class SameDiff extends SDBaseOps { for (SDVariable v : variables()) { if (!v.isPlaceHolder()) continue; - placeholderOffsets[i++] = bufferBuilder.createString(v.getVarName()); + placeholderOffsets[i++] = bufferBuilder.createString(v.name()); } } int placeholdersOffset = FlatGraph.createPlaceholdersVector(bufferBuilder, placeholderOffsets); @@ -5881,9 +5196,8 @@ public class SameDiff extends SDBaseOps { //TODO Infer this properly! Could be constant, etc. VariableType vt = VariableType.values()[v.variabletype()]; - SDVariable var = new SDVariable(n, vt, sd, shape, dtype, null); + SDVariable var = new SDVariable(n, vt, sd, shape, dtype); sd.variables.put(n, Variable.builder().name(n).variable(var).build()); - sd.variableNameToShape.put(n, shape); Variable v2 = sd.variables.get(n); //Reconstruct control dependencies @@ -5979,7 +5293,7 @@ public class SameDiff extends SDBaseOps { if (varIn == null) { //The variable corresponding to this op was not } - inputNames[i] = varIn.getVarName(); + inputNames[i] = varIn.name(); } SameDiffOp op = sd.ops.get(df.getOwnName()); op.setInputsToOp(Arrays.asList(inputNames)); @@ -6038,7 +5352,7 @@ public class SameDiff extends SDBaseOps { if (varsForOp != null && varsForOp.size() == numOutputs) { varNames = new String[varsForOp.size()]; for (int i = 0; i < varNames.length; i++) { - varNames[i] = varsForOp.get(i).getVarName(); + varNames[i] = varsForOp.get(i).name(); sd.getVariables().get(varNames[i]).setOutputOfOp(df.getOwnName()); } sd.ops.get(df.getOwnName()).setOutputsOfOp(Arrays.asList(varNames)); @@ -6051,7 +5365,7 @@ public class SameDiff extends SDBaseOps { varNames[i] = n; if (!sd.variables.containsKey(n)) { //Need to create the variable - perhaps it wasn't exported. Note output of node -> can only be VARIABLE type - SDVariable var = new SDVariable(n, VariableType.VARIABLE, sd, null, null, null); + SDVariable var = new SDVariable(n, VariableType.VARIABLE, sd, null, null); sd.variables.put(n, Variable.builder().name(n).variable(var).build()); variablesByNodeAndOutNum.put(new Pair<>(opId, i), var); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java index 353c2d1e1..d50daddb8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java @@ -440,7 +440,7 @@ public class TrainingConfig { * @param evaluations The evaluations to run */ public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){ - return trainEvaluation(variable.getVarName(), labelIndex, evaluations); + return trainEvaluation(variable.name(), labelIndex, evaluations); } /** @@ -468,7 +468,7 @@ public class TrainingConfig { * @param evaluations The evaluations to run */ public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){ - return validationEvaluation(variable.getVarName(), labelIndex, evaluations); + return validationEvaluation(variable.name(), labelIndex, evaluations); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java index ba5aaa234..ffa0c5f85 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java @@ -73,7 +73,7 @@ public class BatchOutputConfig { public BatchOutputConfig output(@NonNull SDVariable... outputs){ String[] outNames = new String[outputs.length]; for(int i = 0 ; i < outputs.length ; i++){ - outNames[i] = outputs[i].getVarName(); + outNames[i] = outputs[i].name(); } return output(outNames); @@ -104,7 +104,7 @@ public class BatchOutputConfig { * See {@link #input(String, INDArray)} */ public BatchOutputConfig input(@NonNull SDVariable variable, @NonNull INDArray placeholder){ - return input(variable.getVarName(), placeholder); + return input(variable.name(), placeholder); } /** @@ -132,19 +132,35 @@ public class BatchOutputConfig { return this; } + /** + * @deprecated Use {@link #output()} + */ + @Deprecated + public Map exec() { + return output(); + } + /** * Do inference and return the results */ - public Map exec(){ + public Map output(){ return sd.output(placeholders, listeners, outputs.toArray(new String[0])); } + /** + * @deprecated Use {@link #outputSingle()} + */ + @Deprecated + public INDArray execSingle() { + return outputSingle(); + } + /** * Do inference and return the results for the single output * * Only works if exactly one output is specified */ - public INDArray execSingle(){ + public INDArray outputSingle(){ Preconditions.checkState(outputs.size() == 1, "Can only use execSingle() when exactly one output is specified, there were %s", outputs.size()); return exec().get(outputs.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java index f4477adad..389e14306 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java @@ -81,7 +81,7 @@ public class EvaluationConfig { * See {@link #evaluate(String, int, IEvaluation[])} */ public EvaluationConfig evaluate(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){ - return evaluate(variable.getVarName(), labelIndex, evaluations); + return evaluate(variable.name(), labelIndex, evaluations); } /** @@ -106,7 +106,7 @@ public class EvaluationConfig { * See {@link #evaluate(String, IEvaluation[])} */ public EvaluationConfig evaluate(@NonNull SDVariable variable, @NonNull IEvaluation... evaluations){ - return evaluate(variable.getVarName(), evaluations); + return evaluate(variable.name(), evaluations); } /** @@ -129,7 +129,7 @@ public class EvaluationConfig { * See {@link #labelIndex(String, int)} */ public EvaluationConfig labelIndex(@NonNull SDVariable variable, int labelIndex){ - return labelIndex(variable.getVarName(), labelIndex); + return labelIndex(variable.name(), labelIndex); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java index 376430f54..5cafa09aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java @@ -75,7 +75,7 @@ public class OutputConfig { public OutputConfig output(@NonNull SDVariable... outputs) { String[] outNames = new String[outputs.length]; for (int i = 0; i < outputs.length; i++) { - outNames[i] = outputs[i].getVarName(); + outNames[i] = outputs[i].name(); } return output(outNames); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index cbdb39cd6..1f93dbe94 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -204,10 +204,10 @@ public abstract class AbstractSession { VariableType vt = v.getVariableType(); if (vt == VariableType.VARIABLE || vt == VariableType.CONSTANT) { ExecType et = vt == VariableType.VARIABLE ? ExecType.VARIABLE : ExecType.CONSTANT; - ExecStep es = new ExecStep(et, v.getVarName(), new FrameIter(OUTER_FRAME, 0, null)); + ExecStep es = new ExecStep(et, v.name(), new FrameIter(OUTER_FRAME, 0, null)); dt.addDependency(es, start); - Variable var = sameDiff.getVariables().get(v.getVarName()); + Variable var = sameDiff.getVariables().get(v.name()); if (var.getControlDeps() != null) { addVarControlDeps(es, var); //Before this variable can be considered available for use, we need specified op to be executed } @@ -668,11 +668,11 @@ public abstract class AbstractSession { Variable v = sameDiff.getVariables().get(varName); VariableType vt = v.getVariable().getVariableType(); if (vt == VariableType.VARIABLE) { - return new ExecStep(ExecType.VARIABLE, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null)); + return new ExecStep(ExecType.VARIABLE, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null)); } else if (vt == VariableType.PLACEHOLDER) { - return new ExecStep(ExecType.PLACEHOLDER, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null)); + return new ExecStep(ExecType.PLACEHOLDER, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null)); } else if (vt == VariableType.CONSTANT) { - return new ExecStep(ExecType.CONSTANT, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null)); + return new ExecStep(ExecType.CONSTANT, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null)); } else { //Array type. Must be output of an op String outOfOp = v.getOutputOfOp(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 354f537a8..fd89b4653 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -98,9 +98,9 @@ public class InferenceSession extends AbstractSession { //TODO we shouldn't be clearing this on every single iteration, in 99.5% of cases variables will be same as last iteration... for (SDVariable v : sameDiff.variables()) { if (v.getVariableType() == VariableType.CONSTANT) { - arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.getVarName())); + arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.name())); } else if (v.getVariableType() == VariableType.VARIABLE) { - arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.getVarName())); + arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.name())); } } @@ -484,7 +484,7 @@ public class InferenceSession extends AbstractSession { if (op instanceof TensorArray) { //Create a TensorArray - VarId vid = outputFrameIter.toVarId(op.outputVariable().getVarName()); + VarId vid = outputFrameIter.toVarId(op.outputVariable().name()); Preconditions.checkState(!tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", vid); tensorArrays.put(vid, new ArrayList()); @@ -504,18 +504,18 @@ public class InferenceSession extends AbstractSession { SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array //Work out the frame/iteration: - VarId v = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + VarId v = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); if (v == null && allIterInputs != null) { - v = lookup(inTensorArray.getVarName(), allIterInputs, false); + v = lookup(inTensorArray.name(), allIterInputs, false); } - Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.getVarName()); + Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.name()); - while (sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter) { + while (sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) { //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead //TODO also TensorArrayWrite, scatter, etc?? - inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); - v = v.getParentFrame().toVarId(inTensorArray.getVarName()); + inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.name()).arg(); + v = v.getParentFrame().toVarId(inTensorArray.name()); } List list = getTensorArrays().get(v); @@ -528,31 +528,31 @@ public class InferenceSession extends AbstractSession { //TensorArrayWrite - also has a scalar 0.0 that it returns... SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array //Work out the varid (frame/iteration) of the tensor array: - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); if (tArr == null && allIterInputs != null) { - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + tArr = lookup(inTensorArray.name(), allIterInputs, false); } - Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.getVarName()); + Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.name()); - while (sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter) { + while (sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) { //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite //TODO also TensorArrayScatter, etc?? - inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); - tArr = tArr.getParentFrame().toVarId(inTensorArray.getVarName()); + inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.name()).arg(); + tArr = tArr.getParentFrame().toVarId(inTensorArray.name()); } //Input 0 is the TensorArray (or dummy variable that represents it) - but sometimes Enter, in TensorArray -> Enter -> TensorARrayRead //Input 1 is the index //Input 2 is the value to write - String idxName = op.arg(1).getVarName(); + String idxName = op.arg(1).name(); SDVariable idxSDV = sameDiff.getVariable(idxName); INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs); Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", idxArr); int idx = idxArr.getInt(0); - String inName = op.arg(2).getVarName(); + String inName = op.arg(2).name(); SDVariable inSDV = sameDiff.getVariable(inName); INDArray arr = getArray(inSDV, opInputs, allIterInputs); Preconditions.checkState(arr != null, "Could not find array for %s", inName); @@ -577,9 +577,9 @@ public class InferenceSession extends AbstractSession { //Index 0 is the TensorArray (or dummy variable that represents it) SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array //Work out the varid (frame/iteration) of the tensor array: - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); if (tArr == null && allIterInputs != null) { - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + tArr = lookup(inTensorArray.name(), allIterInputs, false); } List l = tensorArrays.get(tArr); Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); @@ -588,9 +588,9 @@ public class InferenceSession extends AbstractSession { return new INDArray[]{scalar}; } else if (op instanceof TensorArrayConcat) { SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); if (tArr == null && allIterInputs != null) { - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + tArr = lookup(inTensorArray.name(), allIterInputs, false); } List l = tensorArrays.get(tArr); @@ -605,14 +605,14 @@ public class InferenceSession extends AbstractSession { //Input 1: the indices (1d integer vector) SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); if (tArr == null && allIterInputs != null) { - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + tArr = lookup(inTensorArray.name(), allIterInputs, false); } List l = tensorArrays.get(tArr); Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - String indicesName = op.arg(1).getVarName(); + String indicesName = op.arg(1).name(); SDVariable indicesSDV = sameDiff.getVariable(indicesName); INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", idxArr, indicesName); @@ -644,22 +644,22 @@ public class InferenceSession extends AbstractSession { //Input 2: The values to scatter SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.getVarName()); - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.name()); + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); if (tArr == null && allIterInputs != null) { - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + tArr = lookup(inTensorArray.name(), allIterInputs, false); } List l = tensorArrays.get(tArr); Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - String indicesName = op.arg(1).getVarName(); + String indicesName = op.arg(1).name(); SDVariable indicesSDV = sameDiff.getVariable(indicesName); INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", idxArr, indicesName); Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", idxArr.dataType(), indicesName); int[] idxs = idxArr.toIntVector(); - String valuesName = op.arg(2).getVarName(); + String valuesName = op.arg(2).name(); SDVariable valuesSDV = sameDiff.getVariable(valuesName); INDArray valuesArr = getArray(valuesSDV, opInputs, allIterInputs); @@ -697,18 +697,18 @@ public class InferenceSession extends AbstractSession { //Input 2: the size of each split (1d integer vector) SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); if (tArr == null && allIterInputs != null) { - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); + tArr = lookup(inTensorArray.name(), allIterInputs, false); } List l = tensorArrays.get(tArr); Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - String splitName = op.arg(1).getVarName(); + String splitName = op.arg(1).name(); INDArray splitArr = getArray(sameDiff.getVariable(splitName), opInputs, allIterInputs); - String sizeName = op.arg(2).getVarName(); + String sizeName = op.arg(2).name(); SDVariable sizeSDV = sameDiff.getVariable(sizeName); INDArray sizeArr = getArray(sizeSDV, opInputs, allIterInputs); Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName); @@ -803,7 +803,7 @@ public class InferenceSession extends AbstractSession { VarId vid = lookup(s, opInputs, allIterInputs, true); args[i] = nodeOutputs.get(vid); } - Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.getVarName()); + Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.name()); i++; } } @@ -825,7 +825,6 @@ public class InferenceSession extends AbstractSession { return sdo; } - df.resolvePropertiesFromSameDiffBeforeExecution(); //TODO This is to be removed List outShape = customOp.calculateOutputShape(); Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName()); String[] outNames = df.outputVariablesNames(); @@ -918,7 +917,6 @@ public class InferenceSession extends AbstractSession { op.setZ(z); } } - df.resolvePropertiesFromSameDiffBeforeExecution(); } return sdo; @@ -926,12 +924,12 @@ public class InferenceSession extends AbstractSession { protected INDArray getArray(SDVariable sdv, Collection opInputs, Collection allIterInputs) { - String n = sdv.getVarName(); + String n = sdv.name(); if (sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE) { return getConstantOrVariable(n); } else { VarId inVarId = lookup(n, opInputs, allIterInputs, false); - Preconditions.checkState(inVarId != null, "Could not find array for variable %s", sdv.getVarName()); + Preconditions.checkState(inVarId != null, "Could not find array for variable %s", sdv.name()); return nodeOutputs.get(inVarId); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java index 22032baf7..992a747a0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java @@ -88,9 +88,9 @@ public class TrainingSession extends InferenceSession { continue; } - requiredActivations.add(grad.getVarName()); + requiredActivations.add(grad.name()); - gradVarToVarMap.put(grad.getVarName(), s); + gradVarToVarMap.put(grad.name(), s); } //Set up losses diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 7da89aa36..8d2e9f624 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -3266,7 +3266,7 @@ public abstract class SDBaseOps { if (cond_result.dataType() != DataType.BOOL) - throw new IllegalStateException("Can not use " + cond_result.getVarName() + " as the condition of an While loop, the condition must be a boolean."); + throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean."); final Set alreadyEntered = Sets.newHashSet(); @@ -3275,7 +3275,7 @@ public abstract class SDBaseOps { for(int i = 0 ; i < loopVars.length ; i++){ SDVariable[] s = f().switchOp(merged[i], cond_result); trueSwitches[i] = s[1]; - alreadyEntered.add(s[1].getVarName()); + alreadyEntered.add(s[1].name()); exits[i] = f().exit(s[0]); } @@ -3290,17 +3290,17 @@ public abstract class SDBaseOps { @Override public SDVariable intercept(SDVariable argument) { - if(!declared.contains(argument.getVarName())) + if(!declared.contains(argument.name())) return argument; - if(alreadyEntered.contains(argument.getVarName())) + if(alreadyEntered.contains(argument.name())) return argument; - if(done.containsKey(argument.getVarName())) - return done.get(argument.getVarName()); + if(done.containsKey(argument.name())) + return done.get(argument.name()); SDVariable e = f().enter(argument, frameName, true); - done.put(argument.getVarName(), e); + done.put(argument.name(), e); return e; } }); @@ -3371,7 +3371,7 @@ public abstract class SDBaseOps { //cleanup partially added block for(SDVariable v : sd().getVariablesInScope(ifScope)) - sd().getVariables().remove(v.getVarName()); + sd().getVariables().remove(v.name()); for(SameDiffOp op : sd().getOpsInScope(ifScope)) { for(String in : op.getInputsToOp()){ @@ -3381,7 +3381,7 @@ public abstract class SDBaseOps { } - throw new IllegalStateException("Can not use " + pred.getVarName() + throw new IllegalStateException("Can not use " + pred.name() + " as the condition of an If statement, the condition must be a boolean."); } @@ -3394,15 +3394,15 @@ public abstract class SDBaseOps { public SDVariable intercept(SDVariable argument) { // if its declared in the if, we don't care acout it - if(!declared.contains(argument.getVarName())) + if(!declared.contains(argument.name())) return argument; // if we've already added a switch, move on - if(switches.containsKey(argument.getVarName())) - return switches.get(argument.getVarName())[1]; + if(switches.containsKey(argument.name())) + return switches.get(argument.name())[1]; SDVariable[] s = f().switchOp(argument, pred); - switches.put(argument.getVarName(), s); + switches.put(argument.name(), s); return s[1]; } }); @@ -3410,9 +3410,9 @@ public abstract class SDBaseOps { SDVariable trueOut = trueBody.define(sd()); sd().removeArgumentInterceptor(); - if(declared.contains(trueOut.getVarName())) { + if(declared.contains(trueOut.name())) { SDVariable[] s = f().switchOp(trueOut, pred); - switches.put(trueOut.getVarName(), s); + switches.put(trueOut.name(), s); trueOut = s[1]; } @@ -3424,15 +3424,15 @@ public abstract class SDBaseOps { public SDVariable intercept(SDVariable argument) { // if its declared in the if, we don't care acout it - if(!declared2.contains(argument.getVarName())) + if(!declared2.contains(argument.name())) return argument; // if we've already added a switch, move on - if(switches.containsKey(argument.getVarName())) - return switches.get(argument.getVarName())[0]; + if(switches.containsKey(argument.name())) + return switches.get(argument.name())[0]; SDVariable[] s = f().switchOp(argument, pred); - switches.put(argument.getVarName(), s); + switches.put(argument.name(), s); return s[0]; } }); @@ -3440,9 +3440,9 @@ public abstract class SDBaseOps { SDVariable falseOut = falseBody.define(sd()); sd().removeArgumentInterceptor(); - if(declared2.contains(falseOut.getVarName())) { + if(declared2.contains(falseOut.name())) { SDVariable[] s = f().switchOp(falseOut, pred); - switches.put(falseOut.getVarName(), s); + switches.put(falseOut.name(), s); falseOut = s[0]; } falseScope.close(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java index d752facd4..f6434a56f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java @@ -37,7 +37,7 @@ public class SDValidation { if (v == null) return; if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) - throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-numerical data type " + v.dataType()); + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-numerical data type " + v.dataType()); } /** @@ -52,7 +52,7 @@ public class SDValidation { return; if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an numerical type type; got variable \"" + - v.getVarName() + "\" with non-integer data type " + v.dataType()); + v.name() + "\" with non-integer data type " + v.dataType()); } /** @@ -65,8 +65,8 @@ public class SDValidation { */ protected static void validateNumerical(String opName, SDVariable v1, SDVariable v2) { if (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8 || v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8) - throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" + - v2.getVarName() + "\" if one or both variables are non-numerical: " + v1.dataType() + " and " + v2.dataType()); + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.name() + "\" and \"" + + v2.name() + "\" if one or both variables are non-numerical: " + v1.dataType() + " and " + v2.dataType()); } /** @@ -79,7 +79,7 @@ public class SDValidation { if (v == null) return; if (!v.dataType().isIntType()) - throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-integer data type " + v.dataType()); + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-integer data type " + v.dataType()); } /** @@ -94,7 +94,7 @@ public class SDValidation { return; if (!v.dataType().isIntType()) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" + - v.getVarName() + "\" with non-integer data type " + v.dataType()); + v.name() + "\" with non-integer data type " + v.dataType()); } /** @@ -107,7 +107,7 @@ public class SDValidation { if (v == null) return; if (!v.dataType().isFPType()) - throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-floating point data type " + v.dataType()); + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-floating point data type " + v.dataType()); } /** @@ -122,7 +122,7 @@ public class SDValidation { return; if (!v.dataType().isFPType()) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an floating point type; got variable \"" + - v.getVarName() + "\" with non-floating point data type " + v.dataType()); + v.name() + "\" with non-floating point data type " + v.dataType()); } /** @@ -135,7 +135,7 @@ public class SDValidation { if (v == null) return; if (v.dataType() != DataType.BOOL) - throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-boolean point data type " + v.dataType()); + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-boolean point data type " + v.dataType()); } /** @@ -150,7 +150,7 @@ public class SDValidation { return; if (v.dataType() != DataType.BOOL) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an boolean variable; got variable \"" + - v.getVarName() + "\" with non-boolean data type " + v.dataType()); + v.name() + "\" with non-boolean data type " + v.dataType()); } /** @@ -162,8 +162,8 @@ public class SDValidation { */ protected static void validateBool(String opName, SDVariable v1, SDVariable v2) { if (v1.dataType() != DataType.BOOL || v2.dataType() != DataType.BOOL) - throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" + - v2.getVarName() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType()); + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.name() + "\" and \"" + + v2.name() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType()); } /** @@ -190,7 +190,7 @@ public class SDValidation { String[] names = new String[vars.length]; DataType[] dtypes = new DataType[vars.length]; for (int j = 0; j < vars.length; j++) { - names[j] = vars[j].getVarName(); + names[j] = vars[j].name(); dtypes[j] = vars[j].dataType(); } throw new IllegalStateException("Cannot perform operation \"" + opName + "\" to variables with different datatypes:" + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index 39e7e479f..e89047ee4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -763,7 +763,7 @@ public class FlatBuffersMapper { SDVariable[] inputs = node.args(); for (SDVariable input : inputs) { - String varName = input.getVarName(); + String varName = input.name(); int outIdx; if (sameDiff.getVariables().get(varName).getOutputOfOp() != null) { DifferentialFunction df = sameDiff.getOps().get(sameDiff.getVariables().get(varName).getOutputOfOp()).getOp(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java index afe8551f3..44bd9b79e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java @@ -69,8 +69,8 @@ public class GraphTransformUtil { // we want to end up with (x -> A -> z) List allSubGraphFns = sg.allFunctionsInSubgraph(); for (int i = 0; i < oldOutputs.size(); i++) { - String oldOutVarName = oldOutputs.get(i).getVarName(); - String newOutVarName = newOutputs.get(i).getVarName(); + String oldOutVarName = oldOutputs.get(i).name(); + String newOutVarName = newOutputs.get(i).name(); Preconditions.checkState(!oldOutVarName.equals(newOutVarName), "Reusing old variables not yet implemented"); //Update inputs for ops: if X->opA, and now Y->opA, then X.inputsForOps contains "opA"; Y.inputsForOps should be updated @@ -133,7 +133,7 @@ public class GraphTransformUtil { //Step 2: Update input variables: if X -> (subgraph) exists, then X.inputsForOp needs to be updated List inputs = sg.inputs(); for (SDVariable v : inputs) { - Variable var = sd.getVariables().get(v.getVarName()); + Variable var = sd.getVariables().get(v.name()); if (var.getInputsForOp() != null) { List newInputsForOp = new ArrayList<>(var.getInputsForOp()); for (String opName : var.getInputsForOp()) { @@ -160,7 +160,7 @@ public class GraphTransformUtil { SDVariable[] outputs = df.outputVariables(); if (outputs != null) { for (SDVariable v : outputs) { - vars.remove(v.getVarName()); + vars.remove(v.name()); } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java index c9f1f52bf..6514ee49e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java @@ -62,7 +62,7 @@ public class SubGraph { //But suppose same subgraph, but connection y -> a exists; then Y must be an output, because it's used somewhere else List filteredOutputs = new ArrayList<>(allOutputs.size()); for(SDVariable v : allOutputs){ - Variable var = sameDiff.getVariables().get(v.getVarName()); + Variable var = sameDiff.getVariables().get(v.name()); List inputsFor = var.getInputsForOp(); boolean allInSubgraph = true; if(inputsFor != null){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java index 5d7e117a2..bf6ba09fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java @@ -77,7 +77,7 @@ public class SubGraphPredicate extends OpPredicate { } SDVariable in = inputs[inNum]; - DifferentialFunction df = sameDiff.getVariableOutputOp(in.getVarName()); + DifferentialFunction df = sameDiff.getVariableOutputOp(in.name()); if (df == null || !e.getValue().matches(sameDiff, df)) { return false; } @@ -103,7 +103,7 @@ public class SubGraphPredicate extends OpPredicate { for(Map.Entry entry : opInputSubgraphPredicates.entrySet()){ OpPredicate p2 = entry.getValue(); SDVariable arg = rootFn.arg(entry.getKey()); - DifferentialFunction df = sd.getVariableOutputOp(arg.getVarName()); + DifferentialFunction df = sd.getVariableOutputOp(arg.name()); if(df != null){ childNodes.add(df); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java index 336ec37d4..0f1e0bd52 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java @@ -107,7 +107,7 @@ public class GradCheckUtil { Set fnOutputs = new HashSet<>(); for(DifferentialFunction f : sd.ops()){ for(SDVariable s : f.outputVariables()){ - fnOutputs.add(s.getVarName()); + fnOutputs.add(s.name()); } } @@ -171,7 +171,7 @@ public class GradCheckUtil { Map grad = new HashMap<>(); for(SDVariable v : sd.variables()){ - if (fnOutputs.contains(v.getVarName())) { + if (fnOutputs.contains(v.name())) { //This is not an input to the graph continue; } @@ -179,20 +179,20 @@ public class GradCheckUtil { //Skip non-fp variables, or variables that don't impact loss function value continue; } - SDVariable g = sd.grad(v.getVarName()); + SDVariable g = sd.grad(v.name()); if(g == null){ - throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\""); + throw new IllegalStateException("Null gradient variable for \"" + v.name() + "\""); } - INDArray ga = gm.get(v.getVarName()); + INDArray ga = gm.get(v.name()); if(ga == null){ - throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName()); + throw new IllegalStateException("Null gradient array encountered for variable: " + v.name()); } if(!Arrays.equals(v.getArr().shape(), ga.shape())){ throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + - v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + + v.name() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape())); } - grad.put(v.getVarName(), ga.dup()); + grad.put(v.name(), ga.dup()); } //Validate gradients for each variable: @@ -201,25 +201,25 @@ public class GradCheckUtil { double maxError = 0.0; Random r = new Random(12345); for(SDVariable s : sd.variables()){ - if (fnOutputs.contains(s.getVarName()) || !s.dataType().isFPType()) { + if (fnOutputs.contains(s.name()) || !s.dataType().isFPType()) { //This is not an input to the graph, or is not a floating point input (so can't be gradient checked) continue; } - if(skipVariables != null && skipVariables.contains(s.getVarName())){ - log.info("Grad check: skipping variable \"{}\"", s.getVarName()); + if(skipVariables != null && skipVariables.contains(s.name())){ + log.info("Grad check: skipping variable \"{}\"", s.name()); continue; } if(s.dataType() != DataType.DOUBLE){ - log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.getVarName(), s.dataType()); + log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.name(), s.dataType()); } - String name = s.getVarName(); + String name = s.name(); INDArray a = s.getArr(); long n = a.length(); if(print){ - log.info("Starting test for variable \"{}\" with {} values", s.getVarName(), n); + log.info("Starting test for variable \"{}\" with {} values", s.name(), n); } Iterator iter; @@ -256,11 +256,11 @@ public class GradCheckUtil { iter = new NdIndexIterator('c',a.shape()); } - INDArray varMask = (gradCheckMask == null ? null : gradCheckMask.get(s.getVarName())); + INDArray varMask = (gradCheckMask == null ? null : gradCheckMask.get(s.name())); if(varMask != null){ - Preconditions.checkState(a.equalShapes(varMask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", s.getVarName(), a.shape(), varMask.shape()); - Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.getVarName(), varMask.dataType()); + Preconditions.checkState(a.equalShapes(varMask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", s.name(), a.shape(), varMask.shape()); + Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.name(), varMask.dataType()); } int i=0; @@ -281,12 +281,12 @@ public class GradCheckUtil { double orig = a.getDouble(idx); a.putScalar(idx, orig+eps); double scorePlus = 0.0; - Map m = sd.exec(placeholderValues, lossFnVariables);//.get(outName).sumNumber().doubleValue(); + Map m = sd.output(placeholderValues, lossFnVariables);//.get(outName).sumNumber().doubleValue(); for(INDArray arr : m.values()){ scorePlus += arr.sumNumber().doubleValue(); } a.putScalar(idx, orig-eps); - m = sd.exec(placeholderValues, lossFnVariables); + m = sd.output(placeholderValues, lossFnVariables); double scoreMinus = 0.0; for(INDArray arr : m.values()){ scoreMinus += arr.sumNumber().doubleValue(); @@ -294,9 +294,9 @@ public class GradCheckUtil { a.putScalar(idx, orig); double numericalGrad = (scorePlus - scoreMinus) / (2 * eps); - INDArray aGrad = grad.get(s.getVarName()); + INDArray aGrad = grad.get(s.name()); if(aGrad == null){ - log.warn("No gradient array for variable \"{}\" was found, skipping variable...", s.getVarName()); + log.warn("No gradient array for variable \"{}\" was found, skipping variable...", s.name()); continue; } double analyticGrad = aGrad.getDouble(idx); @@ -497,12 +497,12 @@ public class GradCheckUtil { listener.setIdx(idx); listener.setEps(config.getEps()); double scorePlus = 0.0; - Map m = sd.exec(config.getPlaceholderValues(), lossFnVariables); + Map m = sd.output(config.getPlaceholderValues(), lossFnVariables); for(INDArray arr : m.values()){ scorePlus += arr.sumNumber().doubleValue(); } listener.setEps(-config.getEps()); - m = sd.exec(config.getPlaceholderValues(), lossFnVariables); + m = sd.output(config.getPlaceholderValues(), lossFnVariables); double scoreMinus = 0.0; for(INDArray arr : m.values()){ scoreMinus += arr.sumNumber().doubleValue(); @@ -597,10 +597,10 @@ public class GradCheckUtil { Set varSetStr = new HashSet<>(); for(SDVariable v : vars){ - if(varSetStr.contains(v.getVarName())){ - throw new IllegalStateException("Variable with name " + v.getVarName() + " already encountered"); + if(varSetStr.contains(v.name())){ + throw new IllegalStateException("Variable with name " + v.name() + " already encountered"); } - varSetStr.add(v.getVarName()); + varSetStr.add(v.name()); } Preconditions.checkState(vars.size() == varSetStr.size(), "Duplicate variables in variables() list"); @@ -645,7 +645,7 @@ public class GradCheckUtil { Map variableMap = sd.getVariables(); Preconditions.checkState(vars.size() == variableMap.size(), "Variable map size check failed"); for(Map.Entry e : variableMap.entrySet()){ - Preconditions.checkState(e.getKey().equals(e.getValue().getVariable().getVarName()), "Name not equal"); + Preconditions.checkState(e.getKey().equals(e.getValue().getVariable().name()), "Name not equal"); } if(generateAndCheckGradFn) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 218004c67..fc7572180 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -208,7 +208,7 @@ public class OpValidation { e.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg()); } - INDArray actual = out.get(v.getVarName()); + INDArray actual = out.get(v.name()); if (actual == null) { throw new IllegalStateException("Null INDArray after forward pass for variable \"" + e.getKey() + "\""); } @@ -271,8 +271,8 @@ public class OpValidation { for( int i=0; i validationFn){ - return expected(var.getVarName(), validationFn); + return expected(var.name(), validationFn); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java index 0165a18f3..7bba2b765 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java @@ -487,11 +487,15 @@ public class LogFileWriter { //Create outputs list: List outputs = sd.outputs(); - int[] outputListStrOffsets = new int[outputs.size()]; - for (int i = 0; i < outputListStrOffsets.length; i++) { - outputListStrOffsets[i] = fbb.createString(outputs.get(i)); + int outputsOffset = 0; + if(outputs != null && !outputs.isEmpty()) { + int[] outputListStrOffsets = new int[outputs.size()]; + for (int i = 0; i < outputListStrOffsets.length; i++) { + outputListStrOffsets[i] = fbb.createString(outputs.get(i)); + } + outputsOffset = UIGraphStructure.createInputsVector(fbb, outputListStrOffsets); } - int outputsOffset = UIGraphStructure.createInputsVector(fbb, outputListStrOffsets); + //Create variables list Map varMap = sd.getVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 7bdea70b5..f7bbc0620 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -46,6 +46,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class, org.nd4j.linalg.api.ops.custom.BarnesHutGains.class, org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class, + org.nd4j.linalg.api.ops.custom.KnnMinDistance.class, org.nd4j.linalg.api.ops.custom.SpTreeCell.class, org.nd4j.linalg.api.ops.custom.Flatten.class, org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd.class, @@ -322,7 +323,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.BinCount.class, org.nd4j.linalg.api.ops.impl.transforms.CheckNumerics.class, org.nd4j.linalg.api.ops.impl.transforms.Cholesky.class, - org.nd4j.linalg.api.ops.impl.transforms.Constant.class, org.nd4j.linalg.api.ops.impl.transforms.Histogram.class, org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth.class, org.nd4j.linalg.api.ops.impl.transforms.IdentityN.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java index be90bb545..233609b19 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java @@ -810,8 +810,6 @@ public class TFGraphMapper { on.setValueFor(currentField, tensor.getFloat(0)); } } - } else { - on.getSameDiff().addPropertyToResolve(on, entry.getKey()); } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java index a41dc8790..4a2e66037 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java @@ -63,19 +63,11 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; - if(Shape.isPlaceholderShape(i_v1.getShape())) { - sameDiff.addPropertyToResolve(this,i_v1.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); } else { throw new IllegalArgumentException("Input not null variables."); } - } public BaseBroadcastBoolOp(SameDiff sameDiff) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java index 7f0d7e40c..6e8b1afa4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java @@ -64,19 +64,10 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; - if(Shape.isPlaceholderShape(i_v1.getShape())) { - sameDiff.addPropertyToResolve(this,i_v1.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); - } else { throw new IllegalArgumentException("Input not null variables."); } - } public BaseBroadcastOp(SameDiff sameDiff) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java index 6e5682962..445a0d585 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java @@ -53,11 +53,8 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum this.dimensions = dimensions; f().validateDifferentialFunctionsameDiff(i_v); sameDiff.addArgsFor(new SDVariable[]{i_v},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); } else { throw new IllegalArgumentException("Input not null variable."); } @@ -75,17 +72,9 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum this.dimensions = dimensions; f().validateDifferentialFunctionsameDiff(i_v); f().validateDifferentialFunctionsameDiff(i_v2); - this.xVertexId = i_v.getVarName(); - this.yVertexId = i_v2.getVarName(); + this.xVertexId = i_v.name(); + this.yVertexId = i_v2.name(); sameDiff.addArgsFor(new SDVariable[]{i_v,i_v2},this); - - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } } else { throw new IllegalArgumentException("Input not null variable."); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java index 5c0577aca..65bdb5c1e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java @@ -247,7 +247,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op { val outputNames = sameDiff.getOutputsForOp(this); //no need to dynamically create if already exists if(outputNames != null) { - zVertexId = sameDiff.getVariable(outputNames[0]).getVarName(); + zVertexId = sameDiff.getVariable(outputNames[0]).name(); return new SDVariable[]{sameDiff.getVariable(outputNames[0])}; @@ -261,7 +261,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op { return newVars; } - sameDiff.setArrayForVariable(newVars[0].getVarName(),inputArr); + sameDiff.setArrayForVariable(newVars[0].name(),inputArr); z = inputArr; if(sameDiff.getOutputsForOp(this) == null) sameDiff.addOutgoingFor(newVars,this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 1e06e7f52..9e5b8f67b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -61,7 +61,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { this.dimensions = dimensions; f().validateDifferentialFunctionsameDiff(i_v); this.keepDims = keepDims; - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); } else { throw new IllegalArgumentException("Input not null variable."); @@ -81,8 +81,8 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { this.dimensions = dimensions; - this.xVertexId = i_v.getVarName(); - this.yVertexId = i_v2.getVarName(); + this.xVertexId = i_v.name(); + this.yVertexId = i_v2.name(); f().validateDifferentialFunctionsameDiff(i_v); f().validateDifferentialFunctionsameDiff(i_v2); this.keepDims = keepDims; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java index 54b53f489..3abca3db2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java @@ -74,11 +74,8 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp { super(sameDiff,inPlace,extraArgs); this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar); if (i_v != null) { - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } f().validateDifferentialFunctionsameDiff(i_v); } else { throw new IllegalArgumentException("Input not null variable."); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java index 0048c9402..ce74a7cd1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java @@ -93,11 +93,8 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp { Object[] extraArgs) { super(sameDiff,inPlace,extraArgs); this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar); - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } f().validateDifferentialFunctionsameDiff(i_v); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java index 8afc68a52..274dfdf72 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java @@ -56,16 +56,9 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; this.inPlace = inPlace; - this.xVertexId = i_v1.getVarName(); - this.yVertexId = i_v2.getVarName(); + this.xVertexId = i_v1.name(); + this.yVertexId = i_v2.name(); sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); - if(Shape.isPlaceholderShape(i_v1.getShape())) { - sameDiff.addPropertyToResolve(this,i_v1.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } } else { throw new IllegalArgumentException("Input not null variables."); } @@ -87,18 +80,9 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { f().validateDifferentialFunctionsameDiff(i_v1); f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; - this.xVertexId = i_v1.getVarName(); - this.yVertexId = i_v2.getVarName(); + this.xVertexId = i_v1.name(); + this.yVertexId = i_v2.name(); sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); - - if(Shape.isPlaceholderShape(i_v1.getShape())) { - sameDiff.addPropertyToResolve(this,i_v1.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } - } else { throw new IllegalArgumentException("Input not null variables."); } @@ -130,14 +114,8 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { if (i_v != null) { f().validateDifferentialFunctionsameDiff(i_v); - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new SDVariable[]{i_v},this); - - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } - - } else { throw new IllegalArgumentException("Input must not null variable."); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 6ebaa5120..99e930176 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -223,7 +223,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { if (args().length >= 1) { val arr = args()[0].getArr(); if (arr != null) { - sameDiff.setArrayForVariable(newVars[0].getVarName(), arr); + sameDiff.setArrayForVariable(newVars[0].name(), arr); addOutputArgument(arr); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java index fd2134aad..208daebf2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java @@ -57,7 +57,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp { public ExternalErrorsFunction(){ } public String getGradPlaceholderName(){ - return arg().getVarName() + "-grad"; + return arg().name() + "-grad"; } @Override @@ -70,7 +70,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp { out = sameDiff.getVariable(name); } else { out = sameDiff.zero(name, Nd4j.dataType(), 1); - sameDiff.getOps().get(getOwnName()).setOutputsOfOp(Collections.singletonList(out.getVarName())); + sameDiff.getOps().get(getOwnName()).setOutputsOfOp(Collections.singletonList(out.name())); sameDiff.getVariables().get(name).setOutputOfOp(getOwnName()); } } @@ -83,7 +83,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp { if (gradVariables == null) { gradVariables = new HashMap<>(); for(SDVariable arg : args()){ - INDArray gradArr = gradients.get(arg.getVarName()); + INDArray gradArr = gradients.get(arg.name()); SDVariable grad; DataType dt = arg.dataType(); String n = getGradPlaceholderName(); @@ -94,7 +94,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp { } else { grad = sameDiff.var(n, VariableType.PLACEHOLDER, null, dt); } - gradVariables.put(arg.getVarName(), grad); + gradVariables.put(arg.name(), grad); out.add(grad); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index dc92d826d..017d341d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -196,12 +196,12 @@ public class DeConv2D extends DynamicCustomOp { val paddingMode = aPadding.getS().toStringUtf8(); val args = args(); - INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr(); + INDArray arr = sameDiff.getVariable(args[1].name()).getArr(); if (arr == null) { arr = TFGraphMapper.getNDArrayFromTensor(nodeDef); // TODO: arguable. it might be easier to permute weights once //arr = (arr.permute(3, 2, 0, 1).dup('c')); - val varForOp = initWith.getVariable(args[1].getVarName()); + val varForOp = initWith.getVariable(args[1].name()); if (arr != null) initWith.associateArrayWithVariable(arr, varForOp); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index 6a3c8854f..0a153422d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -158,10 +158,10 @@ public class DeConv3D extends DynamicCustomOp { val paddingMode = aPadding.getS().toStringUtf8(); val args = args(); - INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr(); + INDArray arr = sameDiff.getVariable(args[1].name()).getArr(); if (arr == null) { arr = TFGraphMapper.getNDArrayFromTensor(nodeDef); - val varForOp = initWith.getVariable(args[1].getVarName()); + val varForOp = initWith.getVariable(args[1].name()); if (arr != null) initWith.associateArrayWithVariable(arr, varForOp); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index 7d711ca58..fa6bceef7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -193,12 +193,6 @@ public class Mmul extends DynamicCustomOp { .transposeA(isTransposeA).transposeB(isTransposeB) .build(); this.mt = mMulTranspose; - val args = args(); - for(val arg : args) { - if(sameDiff.isPlaceHolder(arg.getVarName()) || arg.getShape() == null) { - sameDiff.addPropertyToResolve(this,arg.getVarName()); - } - } iArguments.clear(); addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB())); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index d40a3a334..c860152ca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -130,10 +130,7 @@ public class Concat extends DynamicCustomOp { val variable = initWith.getVariable(input); // concat dimension is only possible - if (variable != null && variable.getArr() == null) { - sameDiff.addPropertyToResolve(this, input); - - } else if (variable != null) { + if (variable != null) { val arr = variable.getArr(); if (arr.length() == 1) { concatDimension = arr.getInt(0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 41224fbb7..32fd91b3a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -124,13 +124,7 @@ public class Transpose extends DynamicCustomOp { return; } - INDArray arr = sameDiff.getArrForVarName(arg().getVarName()); - if (arr == null) { - val arrVar = sameDiff.getVariable(arg().getVarName()); - - arr = arrVar.getWeightInitScheme().create(arrVar.dataType(), arrVar.getShape()); - sameDiff.setArrayForVariable(arg().getVarName(), arr); - } + INDArray arr = sameDiff.getArrForVarName(arg().name()); if(permuteArrayOp != null){ addInputArgument(arr, permuteArrayOp); @@ -138,16 +132,12 @@ public class Transpose extends DynamicCustomOp { addInputArgument(arr); } - - if (arr != null && permuteDims == null) { this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank())); } if (permuteDims != null && permuteDims.length < arg().getShape().length) throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified"); - - } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java index 5b3584d11..b32687844 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java @@ -72,7 +72,7 @@ public class TensorArrayConcat extends BaseTensorOp { public List calculateOutputDataTypes(java.util.List inputDataType){ //Same output type as the TensorArray - which is defined by input 0 SDVariable tArr = arg(0); - TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName()); + TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.name()); org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType(); return Collections.singletonList(dt); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java index 185ef429a..70d200dea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java @@ -72,7 +72,7 @@ public class TensorArrayGather extends BaseTensorOp { public List calculateOutputDataTypes(java.util.List inputDataType){ //Same output type as the TensorArray - which is defined by input 0 SDVariable tArr = arg(0); - TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName()); + TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.name()); org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType(); return Collections.singletonList(dt); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java index 2b8114543..a92185e57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java @@ -73,7 +73,7 @@ public class TensorArrayRead extends BaseTensorOp { dt = importDataType; } else { SDVariable tArr = arg(0); - DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.getVarName()); + DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.name()); TensorArray t3 = (TensorArray) op; dt = t3.getTensorArrayDataType(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/CheckNumerics.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/CheckNumerics.java index 78fdcabba..3334c9bbd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/CheckNumerics.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/CheckNumerics.java @@ -71,9 +71,9 @@ public class CheckNumerics extends DynamicCustomOp { SDVariable msg = initWith.constant(name + "/message", Nd4j.scalar(str)); List newInputs = new ArrayList<>(2); newInputs.addAll(initWith.getOps().get(name).getInputsToOp()); - newInputs.add(msg.getVarName()); + newInputs.add(msg.name()); initWith.getOps().get(name).setInputsToOp(newInputs); - initWith.getVariables().get(msg.getVarName()).setInputsForOp(Collections.singletonList(getOwnName())); } + initWith.getVariables().get(msg.name()).setInputsForOp(Collections.singletonList(getOwnName())); } @Override public List calculateOutputDataTypes(List inputDataTypes){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Constant.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Constant.java deleted file mode 100644 index e003f4325..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Constant.java +++ /dev/null @@ -1,91 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms; - -import lombok.Data; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.UUID; - -@Data -public class Constant extends BaseTransformSameOp { - - - public Constant() { - } - - - protected Constant(SameDiff sameDiff, - SDVariable i_v, - long[] shape, - boolean inPlace) { - super(); - sameDiff.putOrUpdateShapeForVarName(i_v.getVarName(), shape, false); - this.xVertexId = i_v.getVarName(); - this.inPlace = inPlace; - this.sameDiff = sameDiff; - } - - public Constant(SameDiff sameDiff, SDVariable i_v, long[] shape) { - this(sameDiff, i_v, shape, false); - } - - - @Override - public List doDiff(List i_v) { - return Collections.singletonList(sameDiff.zerosLike(arg())); - } - - - @Override - public DifferentialFunction dup() { - Constant ret = new Constant(sameDiff, sameDiff.getVariable(outputVariables()[0].getVarName()) - , sameDiff.getShapeForVarName(outputVariables()[0].getVarName())); - Constant differentialFunction = ret; - return differentialFunction; - } - - - @Override - public int opNum() { - return 15; - } - - @Override - public String opName() { - return "constant"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow opName found for " + opName()); - } - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java index 052700b0a..05993cd7f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java @@ -118,7 +118,7 @@ public class MaxOut extends BaseTransformOp { if(arg() == null) throw new ND4JIllegalStateException("No arg found for op!"); - val arr = sameDiff.getArrForVarName(arg().getVarName()); + val arr = sameDiff.getArrForVarName(arg().name()); if(arr == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java index 4d8209b8a..16afb4316 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java @@ -28,13 +28,19 @@ import java.util.Collections; import java.util.List; /** - * + * TanhDerivative: calculated dL/dIn from dL/dOut and In */ public class TanhDerivative extends DynamicCustomOp { public TanhDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { super(sameDiff, new SDVariable[]{i_v1, i_v2}); } + /** + * + * @param x Input + * @param y Gradient at output (dL/dOut) + * @param z Output array, gradient at input (dL/dIn - to be calculated) + */ public TanhDerivative(INDArray x, INDArray y, INDArray z) { super(null, new INDArray[]{x, y}, new INDArray[]{z}); } @@ -42,6 +48,10 @@ public class TanhDerivative extends DynamicCustomOp { public TanhDerivative() { } + /** + * @param x Input + * @param y Gradient at output (dL/dOut) + */ public TanhDerivative(INDArray x, INDArray y) { this(x, y, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java index f3fb41464..763a64fc4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java @@ -43,11 +43,8 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { public BaseRandomOp(SameDiff sameDiff, SDVariable i_v) { Preconditions.checkNotNull(i_v, "Input variable can't be null with this constructor"); this.sameDiff = sameDiff; - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } } public BaseRandomOp(SameDiff sd, long[] shape){ @@ -73,11 +70,7 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { if(shape != null){ return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType())); } else { - List ret = new ArrayList<>(1); - val shape = sameDiff.getShapeForVarName(args()[0].getVarName()); - if (shape != null) - ret.add(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType()))); - return ret; + return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType()))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 08a546faa..17a8e8a36 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -5212,6 +5212,8 @@ public class Nd4j { } } } + + backend.logBackendInit(); } catch (Exception e) { throw new RuntimeException(e); } @@ -5625,19 +5627,38 @@ public class Nd4j { * @return an ndarray created from the in memory * numpy pointer */ - @SuppressWarnings("WeakerAccess") public static INDArray createFromNpyPointer(Pointer pointer) { return INSTANCE.createFromNpyPointer(pointer); } /** - * Create from a given Numpy .npy file. + * Create an INDArray from a given Numpy .npy file. + * + * @param path Path to the .npy file to read + * @return the created ndarray + */ + public static INDArray readNpy(@NonNull String path){ + return readNpy(new File(path)); + } + + /** + * Create an INDArray from a given Numpy .npy file. * * @param file the file to create the ndarray from * @return the created ndarray */ - public static INDArray createFromNpyFile(File file) { + public static INDArray readNpy(@NonNull File file){ + return createFromNpyFile(file); + } + + /** + * Create an INDArray from a given Numpy .npy file. + * + * @param file the file to create the ndarray from + * @return the created ndarray + */ + public static INDArray createFromNpyFile(@NonNull File file) { if (!file.exists()) throw new IllegalArgumentException("File [" + file.getAbsolutePath() + "] doesn't exist"); @@ -5654,7 +5675,7 @@ public class Nd4j { * @return the loaded ndarray */ @SuppressWarnings("unused") - public static INDArray createNpyFromInputStream(InputStream is) throws IOException { + public static INDArray createNpyFromInputStream(@NonNull InputStream is) throws IOException { byte[] content = IOUtils.toByteArray(is); return createNpyFromByteArray(content); } @@ -5668,7 +5689,7 @@ public class Nd4j { * @param input the input byte array with the npy format * @return the equivalent {@link INDArray} */ - public static INDArray createNpyFromByteArray(byte[] input) { + public static INDArray createNpyFromByteArray(@NonNull byte[] input) { ByteBuffer byteBuffer = ByteBuffer.allocateDirect(input.length); byteBuffer.put(input); byteBuffer.rewind(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java index ae31ea7b8..de9edfc2e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java @@ -20,6 +20,7 @@ import java.util.Properties; import lombok.Getter; import org.bytedeco.javacpp.Loader; import org.nd4j.config.ND4JEnvironmentVars; +import org.nd4j.config.ND4JSystemProperties; import org.nd4j.context.Nd4jContext; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; @@ -101,7 +102,12 @@ public class NativeOpsHolder { } //deviceNativeOps.setOmpNumThreads(4); - log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads()); + String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true"); + boolean logInit = Boolean.parseBoolean(logInitProperty); + + if(logInit) { + log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads()); + } } catch (Exception | Error e) { throw new RuntimeException( "ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html", diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java index c790b40ab..c7eada206 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java @@ -47,7 +47,7 @@ public class MemoryTracker { val f = new AtomicLong(NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(i)); - log.debug("Free memory on device_{}: {}", i, f); + //log.debug("Free memory on device_{}: {}", i, f); freePerDevice.add(i, f); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java index a6a5a45e4..34970dc19 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java @@ -16,14 +16,24 @@ package org.nd4j.linalg.jcublas; +import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.Loader; +import org.nd4j.config.ND4JSystemProperties; +import org.nd4j.linalg.api.environment.Nd4jEnvironment; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.Resource; +import org.nd4j.nativeblas.Nd4jCuda; + +import java.util.List; +import java.util.Map; +import java.util.Properties; /** * */ +@Slf4j public class JCublasBackend extends Nd4jBackend { @@ -76,4 +86,34 @@ public class JCublasBackend extends Nd4jBackend { return JCublasNDArray.class; } + @Override + public void logBackendInit() { + String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true"); + boolean logInit = Boolean.parseBoolean(logInitProperty); + + if(logInit) { + try { + Nd4jCuda.Environment e = Nd4jCuda.Environment.getInstance(); + int blasMajor = e.blasMajorVersion(); + int blasMinor = e.blasMinorVersion(); + int blasPatch = e.blasPatchVersion(); + log.info("ND4J CUDA build version: {}.{}.{}", blasMajor, blasMinor, blasPatch); + int nGPUs = Nd4jEnvironment.getEnvironment().getNumGpus(); + + Properties props = Nd4j.getExecutioner().getEnvironmentInformation(); + List> devicesList = (List>) props.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY); + + for (int i = 0; i < nGPUs; i++) { + Map dev = devicesList.get(i); + String name = (String) dev.get(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY); + int major = ((Number) dev.get(Nd4jEnvironment.CUDA_DEVICE_MAJOR_VERSION_KEY)).intValue(); + int minor = ((Number) dev.get(Nd4jEnvironment.CUDA_DEVICE_MINOR_VERSION_KEY)).intValue(); + long totalMem = ((Number) dev.get(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY)).longValue(); + log.info("CUDA device {}: [{}]; cc: [{}.{}]; Total memory: [{}]", i, name, major, minor, totalMem); + } + } catch (Throwable t) { + log.debug("Error logging CUDA backend versions and devices", t); + } + } + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 6c95d3ce5..057e0bbca 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1890,14 +1890,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { @SuppressWarnings("unchecked") public void printEnvironmentInformation() { super.printEnvironmentInformation(); - - Properties env = getEnvironmentInformation(); - - List> devicesList = (List>) env.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY); - for (Map dev : devicesList) { - log.info("Device Name: [{}]; CC: [{}.{}]; Total/free memory: [{}]", dev.get(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY), - dev.get(Nd4jEnvironment.CUDA_DEVICE_MAJOR_VERSION_KEY), dev.get(Nd4jEnvironment.CUDA_DEVICE_MINOR_VERSION_KEY), dev.get(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY)); - } } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java index 7ba2aa6d7..627105bda 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java @@ -60,4 +60,9 @@ public class CpuBackend extends Nd4jBackend { public Class getNDArrayClass() { return NDArray.class; } + + @Override + public void logBackendInit() { + //No additional logging for CPU backend + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java index 24a05d73e..09e94acc7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java @@ -146,8 +146,8 @@ public class TestSessions extends BaseNd4jTest { System.out.println("----------------------------------"); InferenceSession is = new InferenceSession(sd); -// String outName = merge.getVarName(); - String outName = outVar.getVarName(); +// String outName = merge.name(); + String outName = outVar.name(); Map outMap = is.output(Collections.singletonList(outName), m, null, Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); @@ -181,7 +181,7 @@ public class TestSessions extends BaseNd4jTest { m.put("b", bArr); InferenceSession is = new InferenceSession(sd); - String n = merge.getVarName(); + String n = merge.name(); System.out.println("----------------------------------"); Map outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java index 177a4b795..8fee7f888 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java @@ -118,7 +118,7 @@ public class GraphExecutionerTest extends BaseNd4jTest { SDVariable result = sdVariable.add(scalarOne); SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); - log.info("TOTAL: {}; Id: {}", total.getVarName(), total); + log.info("TOTAL: {}; Id: {}", total.name(), total); INDArray[] resB = executionerB.executeGraph(sameDiff, configVarSpace); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 7939bcdae..b84d7ceea 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -79,7 +79,7 @@ public class LayerOpValidation extends BaseOpValidation { TestCase tc = new TestCase(sameDiff) .gradientCheck(true) - .expectedOutput(res.getVarName(), exp); + .expectedOutput(res.name(), exp); System.out.println(sameDiff.summary()); System.out.println("============================"); @@ -112,7 +112,7 @@ public class LayerOpValidation extends BaseOpValidation { TestCase tc = new TestCase(sameDiff) .gradientCheck(true) - .expectedOutput(res.getVarName(), exp); + .expectedOutput(res.name(), exp); String err = OpValidation.validate(tc); @@ -137,7 +137,7 @@ public class LayerOpValidation extends BaseOpValidation { TestCase tc = new TestCase(sameDiff) .gradientCheck(true) - .expectedOutput(res.getVarName(), exp); + .expectedOutput(res.name(), exp); String err = OpValidation.validate(tc); assertNull(err); @@ -591,7 +591,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable out = sd.cnn().sconv2d(vars, c); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape); @@ -637,7 +637,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable out = sd.cnn().sconv2d(vars, c); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, nOut, 7, 7}, outShape); @@ -688,7 +688,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable out = sd.cnn().deconv2d(vars, deconv); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, nOut, 9, 9}, outShape); @@ -736,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable out = sd.cnn().conv2d("conv", vars, c); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, nOut, 27, 27}, outShape); @@ -770,7 +770,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig); SDVariable out = sd.nn().tanh("out", outPool); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); val outShape = outArr.shape(); // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; assertArrayEquals(new long[]{mb, nIn, 7, 7}, outShape); @@ -828,7 +828,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig); SDVariable out = sd.nn().tanh("out", outPool); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); val outShape = outArr.shape(); // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; assertArrayEquals(new long[]{mb, nIn, 7, 7}, outShape); @@ -996,7 +996,7 @@ public class LayerOpValidation extends BaseOpValidation { } ); - TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.getVarName(), expected); + TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.name(), expected); String err = OpValidation.validate(tc); assertNull(err); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 37c1a7086..7f6daf78f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -423,11 +423,11 @@ public class MiscOpValidation extends BaseOpValidation { } SDVariable loss = sd.sum(scatter); //.standardDeviation(scatter, true); //.sum(scatter); //TODO stdev might be better here as gradients are non-symmetrical... - sd.execAndEndResult(); + TestCase tc = new TestCase(sd) .expected(scatter, exp) - .gradCheckSkipVariables(indices.getVarName()); + .gradCheckSkipVariables(indices.name()); String error = OpValidation.validate(tc); if(error != null){ @@ -493,7 +493,7 @@ public class MiscOpValidation extends BaseOpValidation { TestCase tc = new TestCase(sd) .testName(msg) - .gradCheckSkipVariables(indices.getVarName()); + .gradCheckSkipVariables(indices.name()); if (gatherExp != null) { tc.expected(gather, gatherExp); @@ -589,16 +589,16 @@ public class MiscOpValidation extends BaseOpValidation { Map m = sameDiff.outputAll(null); Map gm = sameDiff.calculateGradients(null, m.keySet()); - SDVariable finalResult = sameDiff.grad(sum.getVarName()); + SDVariable finalResult = sameDiff.grad(sum.name()); - SDVariable cGrad = sameDiff.grad(varMulPre.getVarName()); + SDVariable cGrad = sameDiff.grad(varMulPre.name()); - SDVariable mulGradResult = sameDiff.grad(varMul.getVarName()); - SDVariable aGrad = sameDiff.grad(sdVariable.getVarName()); - SDVariable wGrad = sameDiff.grad(sdVariable1.getVarName()); - SDVariable dGrad = sameDiff.grad(varMul.getVarName()); + SDVariable mulGradResult = sameDiff.grad(varMul.name()); + SDVariable aGrad = sameDiff.grad(sdVariable.name()); + SDVariable wGrad = sameDiff.grad(sdVariable1.name()); + SDVariable dGrad = sameDiff.grad(varMul.name()); - INDArray scalarGradTest = gm.get(sum.getVarName()); + INDArray scalarGradTest = gm.get(sum.name()); assertEquals(scalar, scalarGradTest); @@ -738,11 +738,10 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable B2 = sd.var("B2", B); SDVariable[] batchMul = sd.batchMmul(new SDVariable[] {A1, A2}, new SDVariable[] {B1, B2}); - sd.exec(Collections.emptyMap(), sd.outputs()); - - INDArray resultingMatrix = batchMul[0].getArr(); - System.out.print(resultingMatrix); + Map m = sd.output(Collections.emptyMap(), sd.outputs()); + INDArray resultingMatrix = m.get(batchMul[0].name()); + //System.out.print(resultingMatrix); } @@ -770,14 +769,14 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable mmul = sd.f().mmul(f, s, mt); sd.updateVariableNameAndReference(mmul, "mmul"); - INDArray out = sd.execAndEndResult(); + INDArray out = mmul.eval(); INDArray exp = first.transpose().mmul(second); assertEquals(exp, out); SDVariable loss = sd.standardDeviation(mmul, true); String err = OpValidation.validate(new TestCase(sd) - .expected(mmul.getVarName(), exp)); + .expected(mmul.name(), exp)); assertNull(err); } @@ -1287,7 +1286,7 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable var = sd.var("in", i); SDVariable diag = sd.math().diagPart(var); - INDArray out = sd.execAndEndResult(); + INDArray out = diag.eval(); assertEquals(1, out.rank()); } @@ -1644,10 +1643,10 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable v = new StopGradient(sd, w).outputVariable(); SDVariable loss = v.std(true); - Map gm = sd.calculateGradients(null, v.getVarName(), w.getVarName()); + Map gm = sd.calculateGradients(null, v.name(), w.name()); - INDArray vArr = gm.get(v.getVarName()); - INDArray wArr = gm.get(w.getVarName()); + INDArray vArr = gm.get(v.name()); + INDArray wArr = gm.get(w.name()); System.out.println(vArr); System.out.println(wArr); @@ -1669,18 +1668,18 @@ public class MiscOpValidation extends BaseOpValidation { INDArray expLoss = in.std(true); String err = OpValidation.validate(new TestCase(sd) - .expectedOutput(checkNumerics.getVarName(), in) + .expectedOutput(checkNumerics.name(), in) .placeholderValue("in", in) .expectedOutput("loss", expLoss)); Preconditions.checkState(err == null, err); //Also check that it actually does what it's supposed to: - sd.execAll(Collections.singletonMap("in", in)); + sd.outputAll(Collections.singletonMap("in", in)); in.putScalar(0, Double.NaN); try { - sd.execAll(Collections.singletonMap("in", in)); + sd.outputAll(Collections.singletonMap("in", in)); fail("Expected exception"); } catch (Throwable t){ //OK @@ -1688,14 +1687,14 @@ public class MiscOpValidation extends BaseOpValidation { in.putScalar(0, Double.POSITIVE_INFINITY); try { - sd.execAll(Collections.singletonMap("in", in)); + sd.outputAll(Collections.singletonMap("in", in)); fail("Expected exception"); } catch (Throwable t){ //OK } in.putScalar(0, 0.0); - sd.execAll(Collections.singletonMap("in", in)); + sd.outputAll(Collections.singletonMap("in", in)); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 24ca6540b..802ed9be9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -117,8 +117,8 @@ public class ReductionOpValidation extends BaseOpValidation { SDVariable loss = nonZero.add(zero).castTo(DataType.DOUBLE).std(true); String error = OpValidation.validate(new TestCase(sd) - .expectedOutput(nonZero.getVarName(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 4.0)) - .expectedOutput(zero.getVarName(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 0.0)) + .expectedOutput(nonZero.name(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 4.0)) + .expectedOutput(zero.name(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 0.0)) .gradientCheck(false) ); if (error != null) @@ -148,7 +148,7 @@ public class ReductionOpValidation extends BaseOpValidation { SDVariable zeroFraction = sd.math().zeroFraction(input); String error = OpValidation.validate(new TestCase(sd) - .expectedOutput(zeroFraction.getVarName(), Nd4j.scalar(i == 0 ? 0.5f : 0.0f)) + .expectedOutput(zeroFraction.name(), Nd4j.scalar(i == 0 ? 0.5f : 0.0f)) .gradientCheck(i != 0) ); if (error != null) @@ -429,7 +429,7 @@ public class ReductionOpValidation extends BaseOpValidation { tc.gradientCheck(gradientCheckable); if(exp != null){ - tc.expectedOutput(loss.getVarName(), exp); + tc.expectedOutput(loss.name(), exp); } String error = OpValidation.validate(tc); @@ -996,7 +996,7 @@ public class ReductionOpValidation extends BaseOpValidation { String msg = name + " - dims=" + Arrays.toString(reduceDims); - INDArray out = sd.execAndEndResult(); + INDArray out = reduced.eval(); log.info(msg + " - expected shape: " + Arrays.toString(expShape) + ", out=" + Arrays.toString(out.shape()) + ", outExp=" + Arrays.toString(expOut.shape())); @@ -1069,10 +1069,10 @@ public class ReductionOpValidation extends BaseOpValidation { sd.associateArrayWithVariable(inputArr, input); sd.associateArrayWithVariable(labelArr, label); - INDArray result = sd.execAndEndResult(); + INDArray result = loss.eval(); assertEquals(1, result.length()); - sd.execBackwards(Collections.emptyMap()); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index 8ecdc4eac..988b8da69 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -76,11 +76,11 @@ public class RnnOpValidation extends BaseOpValidation { LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); for(SDVariable sdv : v.getAllOutputs()){ - toExec.add(sdv.getVarName()); + toExec.add(sdv.name()); } //Test forward pass: - Map m = sd.exec(null, toExec); + Map m = sd.output(null, toExec); //Weights and bias order: [i, f, z, o] @@ -179,11 +179,11 @@ public class RnnOpValidation extends BaseOpValidation { LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); for(SDVariable sdv : v.getAllOutputs()){ - toExec.add(sdv.getVarName()); + toExec.add(sdv.name()); } //Test forward pass: - Map m = sd.exec(null, toExec); + Map m = sd.output(null, toExec); INDArray out0 = Nd4j.create(new float[]{0.27817473f, 0.53092605f}, new int[]{1,2}); //Input mod gate INDArray out1 = Nd4j.create(new float[]{-0.18100877f, 0.19417824f}, new int[]{1,2}); //CS (pre tanh) @@ -233,11 +233,11 @@ public class RnnOpValidation extends BaseOpValidation { List v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs(); List toExec = new ArrayList<>(); for(SDVariable sdv : v){ - toExec.add(sdv.getVarName()); + toExec.add(sdv.name()); } //Test forward pass: - Map m = sd.exec(null, toExec); + Map m = sd.output(null, toExec); //Weights and bias order: [r, u], [c] diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 101dcdeaf..401dddc56 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -128,7 +128,7 @@ public class ShapeOpValidation extends BaseOpValidation { //Using stdev here: mean/sum would backprop the same gradient for each input... SDVariable stdev = sd.standardDeviation("out", reshape, true); - INDArray out = sd.execAndEndResult(); + INDArray out = stdev.eval(); INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE); String msg = "toShape=" + Arrays.toString(toShape) + ", order=" + order; @@ -247,7 +247,7 @@ public class ShapeOpValidation extends BaseOpValidation { Map m = sd.outputAll(null); INDArray expOut = in.getArr().std(true); - assertArrayEquals(expExpandShape, m.get(expand.getVarName()).shape()); + assertArrayEquals(expExpandShape, m.get(expand.name()).shape()); INDArray expExpand = inArr.dup('c').reshape(expExpandShape); String msg = "expandDim=" + i + ", source=" + p.getSecond(); @@ -256,7 +256,7 @@ public class ShapeOpValidation extends BaseOpValidation { TestCase tc = new TestCase(sd); tc.testName(msg) .expectedOutput("out", expOut) - .expectedOutput(expand.getVarName(), expExpand); + .expectedOutput(expand.name(), expExpand); String error = OpValidation.validate(tc); if(error != null){ @@ -306,17 +306,17 @@ public class ShapeOpValidation extends BaseOpValidation { Map m = sd.outputAll(null); - INDArray squeezed = m.get(squeeze.getVarName()); + INDArray squeezed = m.get(squeeze.name()); // assertArrayEquals(expShapePostSqueeze, squeezed.shape()); - INDArray out = sd.execAndEndResult(); + INDArray out = m.get(stdev.name()); INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE); assertEquals(expOut, out); String msg = "squeezeDim=" + i + ", source=" + p.getSecond(); TestCase tc = new TestCase(sd) .testName(msg) - .expected(squeeze.getVarName(), exp) + .expected(squeeze.name(), exp) .expectedOutput("out", expOut); @@ -618,7 +618,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable stack = sd.stack(axis, in); - INDArray out = sd.execAndEndResult(); + INDArray out = stack.eval(); assertArrayEquals(expOutShape, out.shape()); if (ArrayUtil.prodLong(shape) == 1) { @@ -714,7 +714,7 @@ public class ShapeOpValidation extends BaseOpValidation { Map m = sd.outputAll(null); for (SDVariable v : unstacked) { - assertArrayEquals(msg, shape, m.get(v.getVarName()).shape()); + assertArrayEquals(msg, shape, m.get(v.name()).shape()); } TestCase tc = new TestCase(sd).testName(msg); @@ -884,7 +884,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray exp = arr.dup('c').reshape('c', 4,3); String err = OpValidation.validate(new TestCase(sameDiff) - .expectedOutput(result1.getVarName(), exp)); + .expectedOutput(result1.name(), exp)); assertNull(err); } @@ -920,7 +920,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable result = sameDiff.transpose(x); SDVariable loss = sameDiff.standardDeviation(result, true); - String err = OpValidation.validate(new TestCase(sameDiff).expectedOutput(result.getVarName(), arr.transpose())); + String err = OpValidation.validate(new TestCase(sameDiff).expectedOutput(result.name(), arr.transpose())); assertNull(err); } @@ -1022,17 +1022,16 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); INDArray ia = Nd4j.create(new double[]{1,2,3}); SDVariable in = sd.var(ia); - SDVariable constant = sd.constant(in, 3); - SDVariable loss = constant.std(true); + SDVariable loss = in.std(true); - assertNull(OpValidation.validate(new TestCase(sd).expected(constant, ia))); + assertNull(OpValidation.validate(new TestCase(sd).expected(in, ia))); //Case 1: shape is provided + scalar sd = SameDiff.create(); ia = Nd4j.scalar(3.0); in = sd.var(ia); - constant = sd.constant(in, 3,4,5); + SDVariable constant = sd.constant(Nd4j.create(DataType.FLOAT, 3,4,5)); INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0); loss = constant.std(true); @@ -1149,7 +1148,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable loss = sameDiff.standardDeviation(result, true); String err = OpValidation.validate(new TestCase(sameDiff) - .expected(result.getVarName(), expected) + .expected(result.name(), expected) .gradientCheck(false)); assertNull(err); } @@ -1172,7 +1171,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray outExp = Nd4j.scalar(d); String err = OpValidation.validate(new TestCase(sd) - .expected(md.getVarName(), outExp)); + .expected(md.name(), outExp)); assertNull(err); } @@ -1196,7 +1195,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray outExp = Nd4j.scalar(d); String err = OpValidation.validate(new TestCase(sd) - .expected(md.getVarName(), outExp)); + .expected(md.name(), outExp)); assertNull(err); } @@ -1227,7 +1226,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray outExp = Nd4j.scalar(d); String err = OpValidation.validate(new TestCase(sd) - .expected(md.getVarName(), outExp)); + .expected(md.name(), outExp)); assertNull(err); } @@ -1247,7 +1246,7 @@ public class ShapeOpValidation extends BaseOpValidation { //System.out.println(d); String err = OpValidation.validate(new TestCase(sd) - .expected(md.getVarName(), Nd4j.scalar(d))); + .expected(md.name(), Nd4j.scalar(d))); assertNull(err); } @@ -1332,7 +1331,7 @@ public class ShapeOpValidation extends BaseOpValidation { .testName(op) .expected(sm, exp) .gradientCheck(true) - .gradCheckSkipVariables(segments.getVarName()); + .gradCheckSkipVariables(segments.name()); String err = OpValidation.validate(tc); if(err != null) @@ -1383,7 +1382,7 @@ public class ShapeOpValidation extends BaseOpValidation { String err = OpValidation.validate(new TestCase(sameDiff) .expected(result1, expected) - .gradCheckSkipVariables(lengths.getVarName())); + .gradCheckSkipVariables(lengths.name())); assertNull(err); // Test with dynamic maxlen @@ -1591,8 +1590,8 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3)); SDVariable x = sameDiff.var("x", arr); SDVariable result = sameDiff.permute(x, 1, 0); - sameDiff.execAll(null); - assertArrayEquals(new long[]{3, 2}, result.getShape()); + Map m = sameDiff.outputAll(null); + assertArrayEquals(new long[]{3, 2}, m.get(result.name()).shape()); } @@ -1629,10 +1628,10 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable slice_full = sd.slice(in, new int[]{0, 0}, new int[]{3, 4}); SDVariable subPart = sd.slice(in, new int[]{1, 2}, new int[]{2, 2}); - sd.exec(Collections.emptyMap(), sd.outputs()); + Map m = sd.outputAll(Collections.emptyMap()); - assertEquals(inArr, slice_full.getArr()); - assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr()); + assertEquals(inArr, m.get(slice_full.name())); + assertEquals(inArr.get(interval(1, 3), interval(2, 4)), m.get(subPart.name())); } @@ -1645,10 +1644,10 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable slice_full = sd.slice(in, new int[]{0, 0, 0}, new int[]{3, 4, 5}); SDVariable subPart = sd.slice(in, new int[]{1, 2, 3}, new int[]{2, 2, 1}); - sd.exec(Collections.emptyMap(), sd.outputs()); + Map m = sd.outputAll(null); - assertEquals(inArr, slice_full.getArr()); - assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), subPart.getArr()); + assertEquals(inArr, m.get(slice_full.name())); + assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), m.get(subPart.name())); } @Test @@ -1661,7 +1660,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable subPart = sd.stridedSlice(in, new int[]{1, 2}, new int[]{3, 4}, new int[]{1, 1}); // SDVariable subPart2 = sd.stridedSlice(in, new int[]{0, 0}, new int[]{4, 5}, new int[]{2, 2}); - sd.execAll(null); + sd.outputAll(null); assertEquals(inArr, slice_full.getArr()); assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr()); @@ -1678,7 +1677,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable slice1 = sd.stridedSlice(in, new int[]{-999, 0}, new int[]{2, 4}, new int[]{1, 1}, 1 << 1, 0, 0, 0, 0); SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 0}, new int[]{-999, 4}, new int[]{1, 1}, 0, 1, 0, 0, 0); - sd.execAll(null); + sd.outputAll(null); assertEquals(inArr.get(NDArrayIndex.interval(0, 2), NDArrayIndex.all()), slice1.getArr()); assertEquals(inArr.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all()), slice2.getArr()); @@ -1695,7 +1694,7 @@ public class ShapeOpValidation extends BaseOpValidation { //[1:3,...,1:4] -> [1:3,:,1:4] SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 1}, new int[]{3, 4}, new int[]{1, 1}, 0, 0, 1 << 1, 0, 0); - sd.execAll(Collections.emptyMap()); + sd.outputAll(Collections.emptyMap()); assertEquals(inArr.get(interval(1, 3), all(), all()), slice.getArr()); assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr()); @@ -1708,7 +1707,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); SDVariable slice = sd.stridedSlice(in, new int[]{-999, 0, 0, 0}, new int[]{-999, 3, 4, 5}, new int[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0); - INDArray out = sd.execAndEndResult(); + INDArray out = slice.eval(); assertArrayEquals(new long[]{1, 3, 4, 5}, out.shape()); assertEquals(inArr, out.get(point(0), all(), all(), all())); @@ -1720,7 +1719,7 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); SDVariable slice = sd.stridedSlice(in, new int[]{1, 1, -999, 1}, new int[]{3, 3, -999, 4}, new int[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0); - INDArray out = sd.execAndEndResult(); + INDArray out = slice.eval(); assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape()); } @@ -1735,7 +1734,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable slice2 = sd.stridedSlice(in, new int[]{2, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1); SDVariable slice3 = sd.stridedSlice(in, new int[]{1, 2, 1}, new int[]{-999, -999, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1); - sd.execAll(null); + sd.outputAll(null); assertEquals(inArr.get(point(0), all(), all()), slice.getArr()); assertEquals(inArr.get(point(2), all(), all()), slice2.getArr()); @@ -1880,8 +1879,8 @@ public class ShapeOpValidation extends BaseOpValidation { // log.info(sd.summary()); - sd.exec(Collections.emptyMap(), Lists.newArrayList(s)); - sd.execBackwards(Collections.emptyMap()); + sd.output(Collections.emptyMap(), Lists.newArrayList(s)); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); } } @@ -2405,8 +2404,8 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable gathered = sd.gather(input, indices, 1); SDVariable loss = gathered.std(true); - sd.exec(null, gathered.getVarName()); - sd.setLossVariables(gathered.getVarName()); + sd.output((Map)null, gathered.name()); + sd.setLossVariables(gathered.name()); String err = OpValidation.validate(new TestCase(sd) .gradCheckEpsilon(1e-3) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index fdd2b3160..bb3bab213 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -115,37 +115,37 @@ public class TransformOpValidation extends BaseOpValidation { switch (i){ case 0: out = in.mul(2); - tc.expectedOutput(out.getVarName(), inArr.mul(2)); + tc.expectedOutput(out.name(), inArr.mul(2)); msg = "mul - " + inOrder; break; case 1: out = in.div(2); - tc.expectedOutput(out.getVarName(), inArr.div(2)); + tc.expectedOutput(out.name(), inArr.div(2)); msg = "div - " + inOrder; break; case 2: out = in.add(2); - tc.expectedOutput(out.getVarName(), inArr.add(2)); + tc.expectedOutput(out.name(), inArr.add(2)); msg = "add - " + inOrder; break; case 3: out = in.sub(2); - tc.expectedOutput(out.getVarName(), inArr.sub(2)); + tc.expectedOutput(out.name(), inArr.sub(2)); msg = "sub - " + inOrder; break; case 4: out = in.rdiv(2); - tc.expectedOutput(out.getVarName(), inArr.rdiv(2)); + tc.expectedOutput(out.name(), inArr.rdiv(2)); msg = "rdiv - " + inOrder; break; case 5: out = in.rsub(2); - tc.expectedOutput(out.getVarName(), inArr.rsub(2)); + tc.expectedOutput(out.name(), inArr.rsub(2)); msg = "rsub - " + inOrder; break; case 6: out = sd.math().pow(in,2); - tc.expectedOutput(out.getVarName(), Transforms.pow(inArr, 2)); + tc.expectedOutput(out.name(), Transforms.pow(inArr, 2)); msg = "pow - " + inOrder; break; case 7: @@ -584,219 +584,219 @@ public class TransformOpValidation extends BaseOpValidation { switch (i) { case 0: t = in.add(5.0); - tc.expectedOutput(t.getVarName(), ia.add(5.0)); + tc.expectedOutput(t.name(), ia.add(5.0)); break; case 1: t = in.sub(5.0); - tc.expectedOutput(t.getVarName(), ia.sub(5.0)); + tc.expectedOutput(t.name(), ia.sub(5.0)); break; case 2: t = in.mul(2.5); - tc.expectedOutput(t.getVarName(), ia.mul(2.5)); + tc.expectedOutput(t.name(), ia.mul(2.5)); break; case 3: t = in.div(4.0); - tc.expectedOutput(t.getVarName(), ia.div(4.0)); + tc.expectedOutput(t.name(), ia.div(4.0)); break; case 4: t = in.rsub(5.0); - tc.expectedOutput(t.getVarName(), ia.rsub(5.0)); + tc.expectedOutput(t.name(), ia.rsub(5.0)); break; case 5: t = in.rdiv(1.0); - tc.expectedOutput(t.getVarName(), ia.rdiv(1.0)); + tc.expectedOutput(t.name(), ia.rdiv(1.0)); break; case 6: t = sd.math().pow(in, 2.5); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 2.5, true)); + tc.expectedOutput(t.name(), Transforms.pow(ia, 2.5, true)); break; case 7: t = sd.nn().sigmoid(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(2).subi(1.0); - tc.expectedOutput(t.getVarName(), Transforms.sigmoid(ia, true)); + tc.expectedOutput(t.name(), Transforms.sigmoid(ia, true)); break; case 8: t = sd.math().tanh(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(2).subi(1.0); - tc.expectedOutput(t.getVarName(), Transforms.tanh(ia, true)); + tc.expectedOutput(t.name(), Transforms.tanh(ia, true)); break; case 9: ia.assign(Nd4j.rand(DataType.DOUBLE, ia.shape())); t = sd.math().tan(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.tan(ia)); + tc.expectedOutput(t.name(), Transforms.tan(ia)); break; case 10: t = sd.math().cos(in); - tc.expectedOutput(t.getVarName(), Transforms.cos(ia, true)); + tc.expectedOutput(t.name(), Transforms.cos(ia, true)); break; case 11: t = sd.math().sin(in); - tc.expectedOutput(t.getVarName(), Transforms.sin(ia, true)); + tc.expectedOutput(t.name(), Transforms.sin(ia, true)); break; case 12: t = sd.nn().softplus(in); - tc.expectedOutput(t.getVarName(), Transforms.softPlus(ia, true)); + tc.expectedOutput(t.name(), Transforms.softPlus(ia, true)); break; case 13: t = sd.math().log(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.log(ia, true)); + tc.expectedOutput(t.name(), Transforms.log(ia, true)); break; case 14: t = sd.math().neg(in); INDArray exp14 = ia.neg(); - tc.expectedOutput(t.getVarName(), exp14); + tc.expectedOutput(t.name(), exp14); break; case 15: t = sd.math().acos(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9); - tc.expectedOutput(t.getVarName(), Transforms.acos(ia, true)); + tc.expectedOutput(t.name(), Transforms.acos(ia, true)); break; case 16: t = sd.math().acosh(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).addi(1.01); //Only defined for x >= 1 - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ACosh(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ACosh(ia.dup()))); break; case 17: t = sd.math().asin(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9); - tc.expectedOutput(t.getVarName(), Transforms.asin(ia, true)); + tc.expectedOutput(t.name(), Transforms.asin(ia, true)); break; case 18: t = sd.math().atan(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(4).subi(2); - tc.expectedOutput(t.getVarName(), Transforms.atan(ia, true)); + tc.expectedOutput(t.name(), Transforms.atan(ia, true)); break; case 19: t = sd.math().atanh(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9); - tc.expectedOutput(t.getVarName(), Transforms.atanh(ia, true)); + tc.expectedOutput(t.name(), Transforms.atanh(ia, true)); break; case 20: t = sd.math().cosh(in); - tc.expectedOutput(t.getVarName(), Transforms.cosh(ia, true)); + tc.expectedOutput(t.name(), Transforms.cosh(ia, true)); break; case 21: t = sd.math().cube(in); - tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 3.0, true)); + tc.expectedOutput(t.name(), Transforms.pow(ia, 3.0, true)); break; case 22: t = sd.nn().elu(in); - tc.expectedOutput(t.getVarName(), Transforms.elu(ia, true)); + tc.expectedOutput(t.name(), Transforms.elu(ia, true)); break; case 23: //TODO SHOULDN'T THIS HAVE A DIMENSION ARG??? t = sd.nn().softmax(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]); break; case 24: t = sd.math().sqrt(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.sqrt(ia, true)); + tc.expectedOutput(t.name(), Transforms.sqrt(ia, true)); break; case 25: t = sd.math().square(in); - tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 2.0, true)); + tc.expectedOutput(t.name(), Transforms.pow(ia, 2.0, true)); break; case 26: t = sd.transpose(in); - tc.expectedOutput(t.getVarName(), ia.transpose().dup()); + tc.expectedOutput(t.name(), ia.transpose().dup()); break; case 27: t = sd.math().abs(in); - tc.expectedOutput(t.getVarName(), Transforms.abs(ia, true)); + tc.expectedOutput(t.name(), Transforms.abs(ia, true)); break; case 28: t = sd.math().sinh(in); - tc.expectedOutput(t.getVarName(), Transforms.sinh(ia, true)); + tc.expectedOutput(t.name(), Transforms.sinh(ia, true)); break; case 29: t = sd.math().asinh(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ASinh(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ASinh(ia.dup()))); break; case 30: t = sd.math().exp(in); - tc.expectedOutput(t.getVarName(), Transforms.exp(ia, true)); + tc.expectedOutput(t.name(), Transforms.exp(ia, true)); break; case 31: t = sd.math().floor(in); - tc.expectedOutput(t.getVarName(), Transforms.floor(ia, true)); + tc.expectedOutput(t.name(), Transforms.floor(ia, true)); break; case 32: t = sd.nn().relu(in, 0.0); ia = Nd4j.rand(minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.relu(ia, true)); + tc.expectedOutput(t.name(), Transforms.relu(ia, true)); break; case 33: t = sd.nn().hardTanh(in); ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0); - tc.expectedOutput(t.getVarName(), Transforms.hardTanh(ia, true)); + tc.expectedOutput(t.name(), Transforms.hardTanh(ia, true)); break; case 34: t = sd.nn().logSigmoid(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new LogSigmoid(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new LogSigmoid(ia.dup()))); break; case 35: t = sd.nn().swish(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Swish(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Swish(ia.dup()))); break; case 36: t = sd.math().sign(in); - tc.expectedOutput(t.getVarName(), Transforms.sign(ia, true)); + tc.expectedOutput(t.name(), Transforms.sign(ia, true)); break; case 37: t = sd.nn().softsign(in); - tc.expectedOutput(t.getVarName(), Transforms.softsign(ia, true)); + tc.expectedOutput(t.name(), Transforms.softsign(ia, true)); break; case 38: t = sd.nn().leakyRelu(in, 0.0); ia = Nd4j.rand(minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.leakyRelu(ia, true)); + tc.expectedOutput(t.name(), Transforms.leakyRelu(ia, true)); break; case 39: if(OpValidationSuite.IGNORE_FAILING) continue; t = sd.nn().logSoftmax(in); ia = Nd4j.rand(minibatch, nOut).muli(10).subi(5); - tc.expectedOutput(t.getVarName(), Transforms.log(Transforms.softmax(ia, true))); + tc.expectedOutput(t.name(), Transforms.log(Transforms.softmax(ia, true))); stdevLoss = true; break; case 40: t = sd.nn().selu(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SELU(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SELU(ia.dup()))); break; case 41: t = sd.gt(in, 1.0).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), ia.gt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.gt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 42: t = sd.gte(in, 1.0).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), ia.gte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.gte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 43: t = sd.lt(in, 1.0).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), ia.lt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.lt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 44: t = sd.lte(in, 1.0).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), ia.lte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.lte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 45: t = sd.eq(in, 2.0).castTo(DataType.DOUBLE); ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut); - tc.expectedOutput(t.getVarName(), ia.eq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.eq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 46: t = sd.neq(in, 2.0).castTo(DataType.DOUBLE); ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut); - tc.expectedOutput(t.getVarName(), ia.neq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.neq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 47: t = sd.math().ceil(in); - tc.expectedOutput(t.getVarName(), Transforms.ceil(ia, true)); + tc.expectedOutput(t.name(), Transforms.ceil(ia, true)); break; case 48: ia = Nd4j.randn(DataType.DOUBLE, ia.shape()).muli(2); @@ -804,7 +804,7 @@ public class TransformOpValidation extends BaseOpValidation { INDArray expOut48 = ia.dup(); BooleanIndexing.replaceWhere(expOut48, -3, Conditions.lessThan(-3)); BooleanIndexing.replaceWhere(expOut48, 2, Conditions.greaterThan(2)); - tc.expectedOutput(t.getVarName(), expOut48); + tc.expectedOutput(t.name(), expOut48); break; case 49: //Clip by norm, dimension 0, some below threshold, some above @@ -825,7 +825,7 @@ public class TransformOpValidation extends BaseOpValidation { expOut49.putColumn(j, origCol.mul(clip / origCol.norm2Number().doubleValue())); } } - tc.expectedOutput(t.getVarName(), expOut49); + tc.expectedOutput(t.name(), expOut49); //System.out.println(expOut.norm2(0)); break; //TODO clip by norm along other dimensions @@ -837,7 +837,7 @@ public class TransformOpValidation extends BaseOpValidation { .addIntegerArguments(dim) .addInputs(ia).addOutputs(expOut50).build(); Nd4j.getExecutioner().exec(reverse); - tc.expectedOutput(t.getVarName(), expOut50); + tc.expectedOutput(t.name(), expOut50); break; case 51: dim = 0; @@ -850,7 +850,7 @@ public class TransformOpValidation extends BaseOpValidation { .addIntegerArguments((exclusive) ? 1 : 0, (reverseBool) ? 1 : 0, dim) .addInputs(ia).addOutputs(expOut51).build(); Nd4j.getExecutioner().exec(cumsum); - tc.expectedOutput(t.getVarName(), expOut51); + tc.expectedOutput(t.name(), expOut51); break; case 52: if(OpValidationSuite.IGNORE_FAILING){ @@ -869,7 +869,7 @@ public class TransformOpValidation extends BaseOpValidation { expOut52.putScalar(s0, s1, prod); } } - tc.expectedOutput(t.getVarName(), expOut52); + tc.expectedOutput(t.name(), expOut52); break; case 53: if(OpValidationSuite.IGNORE_FAILING){ @@ -881,90 +881,90 @@ public class TransformOpValidation extends BaseOpValidation { INDArray expOut53 = Nd4j.create(DataType.DOUBLE, 2, 2); DynamicCustomOp op = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut53).build(); Nd4j.getExecutioner().exec(op); - tc.expectedOutput(t.getVarName(), expOut53); + tc.expectedOutput(t.name(), expOut53); break; case 54: t = sd.math().erf(in); INDArray expOut54 = Nd4j.createUninitialized(DataType.DOUBLE, ia.shape(), ia.ordering()); Nd4j.getExecutioner().exec(new Erf(ia, expOut54)); - tc.expectedOutput(t.getVarName(), expOut54); + tc.expectedOutput(t.name(), expOut54); break; case 55: t = sd.math().erfc(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Erfc(ia, Nd4j.createUninitialized(ia.shape(), ia.ordering())))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Erfc(ia, Nd4j.createUninitialized(ia.shape(), ia.ordering())))); break; case 56: t = sd.math().expm1(in); - tc.expectedOutput(t.getVarName(),Transforms.expm1(ia, true)); + tc.expectedOutput(t.name(),Transforms.expm1(ia, true)); break; case 57: t = sd.math().log1p(in); ia = Nd4j.rand(minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.log1p(ia, true)); + tc.expectedOutput(t.name(), Transforms.log1p(ia, true)); break; case 58: t = sd.math().round(in); - tc.expectedOutput(t.getVarName(), Transforms.round(ia, true)); + tc.expectedOutput(t.name(), Transforms.round(ia, true)); break; case 59: ia = Nd4j.create(new float[]{4, 2}).castTo(DataType.DOUBLE); // in = sd.var("in", new int[]{1, 2}); t = sd.math().rsqrt(in); - tc.expectedOutput(t.getVarName(),Nd4j.getExecutioner().exec(new RSqrt(ia, Nd4j.create(ia.shape(), ia.ordering())))); + tc.expectedOutput(t.name(),Nd4j.getExecutioner().exec(new RSqrt(ia, Nd4j.create(ia.shape(), ia.ordering())))); break; case 60: t = sd.nn().relu6(in, 0); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(),Transforms.relu6(ia, true)); + tc.expectedOutput(t.name(),Transforms.relu6(ia, true)); break; case 61: ia = Nd4j.create(new float[] {2, 2}).castTo(DataType.DOUBLE); sd.associateArrayWithVariable(ia, in); double value = 42; t = sd.fill(in.castTo(DataType.INT), DataType.DOUBLE, value); - tc.expectedOutput(t.getVarName(), Nd4j.valueArrayOf(new int[]{2,2}, 42)).gradientCheck(false); + tc.expectedOutput(t.name(), Nd4j.valueArrayOf(new int[]{2,2}, 42)).gradientCheck(false); opName = "fill"; break; case 62: t = sd.nn().hardSigmoid(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new HardSigmoid(ia, ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new HardSigmoid(ia, ia.dup()))); break; case 63: t = sd.scalarMax(in, 0.5); - tc.expectedOutput(t.getVarName(), Transforms.max(ia, 0.5, true)); + tc.expectedOutput(t.name(), Transforms.max(ia, 0.5, true)); break; case 64: t = sd.scalarMin(in, 0.5); - tc.expectedOutput(t.getVarName(), Transforms.min(ia, 0.5, true)); + tc.expectedOutput(t.name(), Transforms.min(ia, 0.5, true)); break; case 65: t = sd.assign(in, 0.5); - tc.expectedOutput(t.getVarName(), ia.dup().assign(0.5)); + tc.expectedOutput(t.name(), ia.dup().assign(0.5)); break; case 66: t = sd.scalarFloorMod(in, 0.5); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5))); break; case 67: t = sd.math().reciprocal(in); - tc.expectedOutput(t.getVarName(), ia.rdiv(1.0)); + tc.expectedOutput(t.name(), ia.rdiv(1.0)); break; case 68: t = sd.shape(in).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), Nd4j.create(ArrayUtil.toDouble(ia.shape()))).gradientCheck(false); + tc.expectedOutput(t.name(), Nd4j.create(ArrayUtil.toDouble(ia.shape()))).gradientCheck(false); break; case 69: t = sd.rank(in).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), Nd4j.scalar((double)ia.rank())).gradientCheck(false); + tc.expectedOutput(t.name(), Nd4j.scalar((double)ia.rank())).gradientCheck(false); break; case 70: t = sd.onesLike(in); - tc.expectedOutput(t.getVarName(), Nd4j.ones(ia.shape())); + tc.expectedOutput(t.name(), Nd4j.ones(ia.shape())); break; case 71: ia = Nd4j.randn(DataType.DOUBLE, nOut, nOut); t = sd.math().diagPart(in); - tc.expectedOutput(t.getVarName(), Nd4j.create(new double[]{ia.getDouble(0,0), ia.getDouble(1,1), ia.getDouble(2,2), ia.getDouble(3,3)}).castTo(DataType.DOUBLE)); + tc.expectedOutput(t.name(), Nd4j.create(new double[]{ia.getDouble(0,0), ia.getDouble(1,1), ia.getDouble(2,2), ia.getDouble(3,3)}).castTo(DataType.DOUBLE)); break; case 72: t = sd.identity(in); @@ -1087,109 +1087,109 @@ public class TransformOpValidation extends BaseOpValidation { switch (i) { case 0: t = in1.add(in2); - tc.expectedOutput(t.getVarName(), ia.add(ib)); + tc.expectedOutput(t.name(), ia.add(ib)); break; case 1: t = in1.sub(in2); - tc.expectedOutput(t.getVarName(),ia.sub(ib)); + tc.expectedOutput(t.name(),ia.sub(ib)); break; case 2: t = in1.mul(in2); - tc.expectedOutput(t.getVarName(), ia.mul(ib)); + tc.expectedOutput(t.name(), ia.mul(ib)); break; case 3: t = in1.div(in2); - tc.expectedOutput(t.getVarName(), ia.div(ib)); + tc.expectedOutput(t.name(), ia.div(ib)); break; case 4: t = in1.rsub(in2); - tc.expectedOutput(t.getVarName(), ia.rsub(ib)); + tc.expectedOutput(t.name(), ia.rsub(ib)); break; case 5: ia.assign(Nd4j.rand(ia.shape())).addi(0.5); ib.assign(Nd4j.rand(ib.shape())).addi(0.5); t = in1.rdiv(in2); - tc.expectedOutput(t.getVarName(), ia.rdiv(ib)); + tc.expectedOutput(t.name(), ia.rdiv(ib)); break; case 6: t = sd.eq(in1, in2); opName = "eq"; - tc.expectedOutput(t.getVarName(), ia.eq(ib)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.eq(ib)).gradientCheck(false); break; case 7: t = sd.neq(in1, in2); opName = "neq"; - tc.expectedOutput(t.getVarName(), ia.neq(ib)).gradientCheck(false);; + tc.expectedOutput(t.name(), ia.neq(ib)).gradientCheck(false);; break; case 8: t = sd.gt(in1, in2); opName = "gt"; - tc.expectedOutput(t.getVarName(), ia.gt(ib)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.gt(ib)).gradientCheck(false); break; case 9: t = sd.lt(in1, in2); opName = "lt"; - tc.expectedOutput(t.getVarName(), ia.lt(ib)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.lt(ib)).gradientCheck(false); break; case 10: t = sd.gte(in1, in2); opName = "gte"; INDArray expOut10 = Nd4j.create(DataType.BOOL, ia.shape()); Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut10})); - tc.expectedOutput(t.getVarName(), expOut10).gradientCheck(false); + tc.expectedOutput(t.name(), expOut10).gradientCheck(false); break; case 11: t = sd.lte(in1, in2); opName = "lte"; INDArray expOut11 = Nd4j.create(DataType.BOOL, ia.shape()); Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut11})); - tc.expectedOutput(t.getVarName(), expOut11).gradientCheck(false); + tc.expectedOutput(t.name(), expOut11).gradientCheck(false); break; case 12: ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().or(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); opName = "or"; - tc.expectedOutput(t.getVarName(), Transforms.or(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); + tc.expectedOutput(t.name(), Transforms.or(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); break; case 13: ib = Nd4j.randn(DataType.DOUBLE, nOut, nOut); t = sd.mmul(in1, in2); - tc.expectedOutput(t.getVarName(), ia.mmul(ib)); + tc.expectedOutput(t.name(), ia.mmul(ib)); break; case 14: t = sd.max(in1, in2); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]); break; case 15: t = sd.min(in1, in2); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]); break; case 16: ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().and(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); opName = "and"; - tc.expectedOutput(t.getVarName(), Transforms.and(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); + tc.expectedOutput(t.name(), Transforms.and(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); break; case 17: ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().xor(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); opName = "xor"; - tc.expectedOutput(t.getVarName(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); + tc.expectedOutput(t.name(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); break; case 18: t = sd.assign(in1, in2); - tc.expectedOutput(t.getVarName(), ib); + tc.expectedOutput(t.name(), ib); break; case 19: t = sd.math().atan2(in1, in2); - tc.expectedOutput(t.getVarName(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms + tc.expectedOutput(t.name(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms break; case 20: t = sd.math().mergeAdd(in1, in2, in2); - tc.expectedOutput(t.getVarName(), ia.add(ib).add(ib)); + tc.expectedOutput(t.name(), ia.add(ib).add(ib)); break; case 21: t = in1.squaredDifference(in2); @@ -1199,7 +1199,7 @@ public class TransformOpValidation extends BaseOpValidation { .addOutputs(expOut21) .build(); Nd4j.getExecutioner().exec(squareDiff); - tc.expectedOutput(t.getVarName(), expOut21); + tc.expectedOutput(t.name(), expOut21); break; case 22: //set diag @@ -1210,7 +1210,7 @@ public class TransformOpValidation extends BaseOpValidation { expOut22.putScalar(j,j, ib.getDouble(j)); } t = sd.math().setDiag(in1, in2); - tc.expectedOutput(t.getVarName(), expOut22); + tc.expectedOutput(t.name(), expOut22); break; default: throw new RuntimeException(); @@ -1341,7 +1341,6 @@ public class TransformOpValidation extends BaseOpValidation { } } - //TODO UPDATE TO OP VALIDATION OR DELETE @Test public void testLogGrad() { SameDiff sameDiff = SameDiff.create(); @@ -1349,7 +1348,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable log = sameDiff.math().log(input); SDVariable sum = sameDiff.sum(log, Integer.MAX_VALUE); INDArray result = null; - sameDiff.execBackwards(Collections.emptyMap()); + sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); } @@ -1362,8 +1361,8 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable input = sameDiff.var("x", inputs.get("x")); SDVariable sigmoid = sameDiff.nn().sigmoid(input); SDVariable sum = sameDiff.sum(sigmoid, Integer.MAX_VALUE); - sameDiff.execBackwards(Collections.emptyMap()); - INDArray arr = input.gradient().getArr(); + Map m = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); + INDArray arr = m.get(input.name()); assertTrue(Nd4j.create(new double[][]{ {0.1966, 0.1050}, {0.0452, 0.0177} @@ -1384,12 +1383,12 @@ public class TransformOpValidation extends BaseOpValidation { public void testRank0EdgeCase(){ SameDiff sd = SameDiff.create(); SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))); - double d0 = sd.execAndEndResult().getDouble(0); + double d0 = v1.eval().getDouble(0); assertEquals(8, d0, 0); SDVariable v2 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))).div(2.0); - sd.exec(Collections.emptyMap(), sd.outputs()); - double d1 = v2.getArr().getDouble(0); + Map m = sd.outputAll(Collections.emptyMap()); + double d1 = m.get(v2.name()).getDouble(0); assertEquals(4, d1, 0); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java index 308ef0fd4..9d89aec91 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java @@ -87,12 +87,12 @@ public class FailingSameDiffTests extends BaseNd4jTest { SDVariable tanh = sd.math().tanh(in); INDArray exp = Transforms.tanh(in.getArr(), true); - INDArray out = sd.execAndEndResult(); + INDArray out = tanh.eval(); assertEquals(exp, out); //Now, replace with minibatch 5: in.setArray(Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(5,4)); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = tanh.eval(); assertArrayEquals(new long[]{5,4}, out2.shape()); exp = Transforms.tanh(in.getArr(), true); @@ -124,12 +124,12 @@ public class FailingSameDiffTests extends BaseNd4jTest { SDVariable mmul = sd.mmul(in,w).add(b); INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr()); - INDArray out = sd.execAndEndResult(); + INDArray out = mmul.eval(); assertEquals(exp, out); //Now, replace with minibatch 5: in.setArray(Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(5,4)); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = mmul.eval(); assertArrayEquals(new long[]{5,5}, out2.shape()); exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr()); @@ -137,11 +137,10 @@ public class FailingSameDiffTests extends BaseNd4jTest { //Generate gradient function, and exec SDVariable loss = mmul.std(true); - sd.execBackwards(Collections.emptyMap()); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); in.setArray(Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4)); - sd.execAndEndResult(); - out2 = mmul.getArr(); + out2 = mmul.eval(); assertArrayEquals(new long[]{3,5}, out2.shape()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java index a3ee34570..e704c9337 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java @@ -173,7 +173,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { } if(execFirst){ - sd.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName())); + sd.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name())); } File f = testDir.newFile(); @@ -186,7 +186,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { List varsRestored = restored.variables(); assertEquals(varsOrig.size(), varsRestored.size()); for (int j = 0; j < varsOrig.size(); j++) { - assertEquals(varsOrig.get(j).getVarName(), varsRestored.get(j).getVarName()); + assertEquals(varsOrig.get(j).name(), varsRestored.get(j).name()); } DifferentialFunction[] fOrig = sd.ops(); @@ -200,10 +200,10 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { assertEquals(sd.getLossVariables(), restored.getLossVariables()); - Map m = sd.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName())); - INDArray outOrig = m.get(x.getVarName()); - Map m2 = restored.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName())); - INDArray outRestored = m2.get(x.getVarName()); + Map m = sd.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name())); + INDArray outOrig = m.get(x.name()); + Map m2 = restored.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name())); + INDArray outRestored = m2.get(x.name()); assertEquals(String.valueOf(i), outOrig, outRestored); @@ -320,7 +320,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { if(v.isPlaceHolder() || v.getVariableType() == VariableType.ARRAY) continue; - SDVariable v2 = sd2.getVariable(v.getVarName()); + SDVariable v2 = sd2.getVariable(v.name()); INDArray a1 = v.getArr(); INDArray a2 = v2.getArr(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java index 083fdbfe8..b3b82e16c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java @@ -57,17 +57,17 @@ public class GraphTransformUtilTests extends BaseNd4jTest { SDVariable sub = add.sub(add2); - assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.getVarName()))); - assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.getVarName()))); - assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.getVarName()))); + assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.name()))); + assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.name()))); + assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.name()))); - assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.getVarName()))); - assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.getVarName()))); - assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.getVarName()))); + assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.name()))); + assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.name()))); + assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.name()))); - assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.getVarName()))); - assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.getVarName()))); - assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.getVarName()))); + assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.name()))); + assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.name()))); + assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.name()))); SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.classEquals(AddOp.class)); @@ -76,11 +76,11 @@ public class GraphTransformUtilTests extends BaseNd4jTest { assertEquals(2, l.size()); SubGraph sg1 = l.get(0); - assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.getVarName())); + assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.name())); assertEquals(0, sg1.getChildNodes().size()); SubGraph sg2 = l.get(1); - assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.getVarName())); + assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.name())); assertEquals(0, sg2.getChildNodes().size()); } @@ -118,7 +118,7 @@ public class GraphTransformUtilTests extends BaseNd4jTest { }); INDArray exp2 = p1.div(p2).mul(p1.sub(p2)); - INDArray out2 = sd2.getVariable(mul.getVarName()).eval(); + INDArray out2 = sd2.getVariable(mul.name()).eval(); assertEquals(exp2, out2); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java index 659fc5438..9532cb0f5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java @@ -33,18 +33,18 @@ public class NameScopeTests extends BaseNd4jTest { SDVariable v = sd.var("x"); try(NameScope ns = sd.withNameScope("nameScope")){ SDVariable v2 = sd.var("x2"); - assertEquals("nameScope/x2", v2.getVarName()); + assertEquals("nameScope/x2", v2.name()); assertTrue(sd.getVariables().containsKey("nameScope/x2")); assertEquals("nameScope", sd.currentNameScope()); SDVariable v3 = sd.var("x"); - assertEquals("nameScope/x", v3.getVarName()); + assertEquals("nameScope/x", v3.name()); assertTrue(sd.getVariables().containsKey("nameScope/x")); try(NameScope ns2 = sd.withNameScope("scope2")){ assertEquals("nameScope/scope2", sd.currentNameScope()); SDVariable v4 = sd.var("x"); - assertEquals("nameScope/scope2/x", v4.getVarName()); + assertEquals("nameScope/scope2/x", v4.name()); assertTrue(sd.getVariables().containsKey("nameScope/scope2/x")); } @@ -76,19 +76,19 @@ public class NameScopeTests extends BaseNd4jTest { } SDVariable a = sd.var("a", DataType.FLOAT, 1); - assertEquals("x", x.getVarName()); - assertEquals("s1/y", y.getVarName()); - assertEquals("s1/s2/z", z.getVarName()); - assertEquals("a", a.getVarName()); + assertEquals("x", x.name()); + assertEquals("s1/y", y.name()); + assertEquals("s1/s2/z", z.name()); + assertEquals("a", a.name()); - assertTrue(add.getVarName(), add.getVarName().startsWith("s1/")); - assertEquals("s1/addxy", addWithName.getVarName()); + assertTrue(add.name(), add.name().startsWith("s1/")); + assertEquals("s1/addxy", addWithName.name()); - assertTrue(merge.getVarName(), merge.getVarName().startsWith("s1/s2/")); - assertEquals("s1/s2/mmax", mergeWithName.getVarName()); + assertTrue(merge.name(), merge.name().startsWith("s1/s2/")); + assertEquals("s1/s2/mmax", mergeWithName.name()); Set allowedVarNames = new HashSet<>(Arrays.asList("x", "s1/y", "s1/s2/z", "a", - add.getVarName(), addWithName.getVarName(), merge.getVarName(), mergeWithName.getVarName())); + add.name(), addWithName.name(), merge.name(), mergeWithName.name())); Set allowedOpNames = new HashSet<>(); //Check op names: @@ -102,8 +102,8 @@ public class NameScopeTests extends BaseNd4jTest { //Check fields - Variable, SDOp, etc for(Variable v : sd.getVariables().values()){ - assertTrue(v.getVariable().getVarName(), allowedVarNames.contains(v.getVariable().getVarName())); - assertEquals(v.getName(), v.getVariable().getVarName()); + assertTrue(v.getVariable().name(), allowedVarNames.contains(v.getVariable().name())); + assertEquals(v.getName(), v.getVariable().name()); if(v.getInputsForOp() != null){ for(String s : v.getInputsForOp()){ assertTrue(s, allowedOpNames.contains(s)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java index 5d6b7bfc3..303739ea1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java @@ -108,14 +108,14 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { sd.fit(ds); } - for(String s : new String[]{"w", "b", badd.getVarName(), add.getVarName(), "l1", "l2"}){ + for(String s : new String[]{"w", "b", badd.name(), add.name(), "l1", "l2"}){ SDVariable gradVar = sd.getVariable(s).gradient(); assertNotNull(s, gradVar); } //Unused: assertFalse(shape.hasGradient()); try{ assertNull(shape.gradient()); } catch (IllegalStateException e){ assertTrue(e.getMessage().contains("only floating point variables")); } - for(String s : new String[]{unused1.getVarName(), unused2.getVarName(), unused3.getVarName()}){ + for(String s : new String[]{unused1.name(), unused2.name(), unused3.name()}){ assertNull(sd.getVariable(s).gradient()); } } @@ -151,20 +151,20 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { sd.setLossVariables("loss1"); sd.createGradFunction(); for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){ - assertNotNull(v.getVarName(), v.gradient()); + assertNotNull(v.name(), v.gradient()); } for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){ - assertNull(v.getVarName(), v.gradient()); + assertNull(v.name(), v.gradient()); } //Now, set to other loss function sd.setLossVariables("loss2"); sd.createGradFunction(); for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){ - assertNull(v.getVarName(), v.gradient()); + assertNull(v.name(), v.gradient()); } for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){ - assertNotNull(v.getVarName(), v.gradient()); + assertNotNull(v.name(), v.gradient()); } //Train the first side of the graph. The other side should remain unmodified! diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index f19a4ec8b..b67d82110 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -16,12 +16,7 @@ package org.nd4j.autodiff.samediff; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; import static org.junit.Assume.assumeNotNull; import static org.nd4j.linalg.indexing.NDArrayIndex.all; @@ -151,7 +146,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(Nd4j.create(new double[]{1, 2, 3, 4, 5, 6}, new long[]{2, 3}), input); - sd.execAndEndResult(); + sd.outputAll(null); nodeA.isPlaceHolder(); } @@ -184,10 +179,10 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(inputArr, input); sd.associateArrayWithVariable(labelArr, label); - INDArray result = sd.execAndEndResult(); + INDArray result = avgMSE.eval(); assertEquals(1, result.length()); - sd.execBackwards(Collections.emptyMap()); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); } @Test @@ -208,10 +203,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable x = sameDiff.var("x", arr); SDVariable result = sameDiff.sum(x, 1); //[1,4].sum(1) == [1] - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - INDArray exp = Nd4j.scalar(arr.sumNumber().floatValue()).reshape(1); - INDArray resultArr = result.getArr(); + INDArray resultArr = result.eval(); assertEquals(exp, resultArr); } @@ -226,7 +219,7 @@ public class SameDiffTests extends BaseNd4jTest { Map m = new HashMap<>(); m.put("x", x); m.put("y", y); - INDArray out = sameDiff.exec(m, Collections.singletonList(output.getVarName())).get(output.getVarName()); + INDArray out = sameDiff.output(m, Collections.singletonList(output.name())).get(output.name()); INDArray outputAssertion = x.add(y); assertEquals(outputAssertion, out); } @@ -243,10 +236,9 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sdTargets = sameDiff.var("targets", targets); SDVariable res = sameDiff.loss().weightedCrossEntropyWithLogits(sdTargets, sdInputs, sdWeights); - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - INDArray resultArray = res.getArr(); - assertArrayEquals(new long[]{1, 5}, res.getShape()); + INDArray resultArray = res.eval(); + assertArrayEquals(new long[]{1, 5}, resultArray.shape()); } @Test @@ -270,7 +262,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(inputArr, input); sd.associateArrayWithVariable(labelArr, label); - INDArray result = sd.execAndEndResult(); + INDArray result = score.eval(); assertNotNull(result); //*** Fails Here - Null output *** assertEquals(1, result.length()); } @@ -284,8 +276,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable result = sameDiff.math().cosineSimilarity(x, y, 1); SDVariable addResult = result.add(result); SDVariable finalReshape = sameDiff.reshape(addResult, 1, 2); - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - assertArrayEquals(new long[]{1, 2}, finalReshape.getShape()); + Map out = sameDiff.output(Collections.emptyMap(), finalReshape.name()); + assertArrayEquals(new long[]{1, 2}, out.get(finalReshape.name()).shape()); } @Test @@ -297,7 +289,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable result = sameDiff.mmul(x, y); SDVariable otherResult = result.add(result); Map m = sameDiff.outputAll(null); - assertArrayEquals(new long[]{2, 2}, m.get(result.getVarName()).shape()); + assertArrayEquals(new long[]{2, 2}, m.get(result.name()).shape()); } @@ -308,7 +300,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable x = sameDiff.var("x", arr); SDVariable sigmoid = sameDiff.nn().sigmoid("s", x); INDArray assertion = Transforms.sigmoid(arr); - INDArray eval = sameDiff.exec(Collections.singletonMap("x", arr), Collections.singletonList("s")).get("s"); + INDArray eval = sameDiff.output(Collections.singletonMap("x", arr), Collections.singletonList("s")).get("s"); assertEquals(assertion, eval); } @@ -318,8 +310,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable var = sameDiff.var("one", Nd4j.scalar(1.0)); SDVariable variable2 = sameDiff.var("two", Nd4j.scalar(1.0)); val sum = var.add(variable2); - sum.eval(); - assertArrayEquals(new long[0], sum.getShape()); + INDArray out = sum.eval(); + assertArrayEquals(new long[0], out.shape()); } @@ -330,8 +322,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable firstVar = first.var("one", new long[]{2, 2}); SDVariable secondVar = second.var(firstVar); - assertTrue(firstVar.getArr() == secondVar.getArr()); - assertEquals(firstVar.getVarName(), secondVar.getVarName()); + assertEquals(firstVar.getArr(), secondVar.getArr()); + assertEquals(firstVar.name(), secondVar.name()); } @@ -344,8 +336,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable secondVar = second.var(firstVar); assumeNotNull(firstVar.getArr()); - assertTrue(firstVar.getArr() == secondVar.getArr()); - assertEquals(firstVar.getVarName(), secondVar.getVarName()); + assertEquals(firstVar.getArr(), secondVar.getArr()); + assertEquals(firstVar.name(), secondVar.name()); } @@ -370,7 +362,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable x = sameDiff.var("x", arr); SDVariable s = x.mul("s", x); INDArray assertion = arr.mul(arr); - INDArray eval = sameDiff.exec(Collections.singletonMap("x", arr), Collections.singletonList("s")).get("s"); + INDArray eval = sameDiff.output(Collections.singletonMap("x", arr), Collections.singletonList("s")).get("s"); assertEquals(assertion, eval); } @@ -387,7 +379,7 @@ public class SameDiffTests extends BaseNd4jTest { Map vars = new HashMap<>(); vars.put("x", arr); vars.put("y", yArr); - INDArray eval = sameDiff.exec(vars, Collections.singletonList(sigmoid.getVarName())).get(sigmoid.getVarName()); + INDArray eval = sameDiff.output(vars, Collections.singletonList(sigmoid.name())).get(sigmoid.name()); assertEquals(assertion, eval); } @@ -414,7 +406,7 @@ public class SameDiffTests extends BaseNd4jTest { public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { SDVariable x = sameDiff.var("x", inputs.get("x")); SDVariable y = sameDiff.var("y", inputs.get("y")); - return new SDVariable[]{x.div(y)}; + return new SDVariable[]{x.div("out", y)}; } }, xAndY); @@ -423,14 +415,14 @@ public class SameDiffTests extends BaseNd4jTest { public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { SDVariable x = sameDiff.var("x", inputs.get("x")); SDVariable y = sameDiff.var("y", inputs.get("y")); - return new SDVariable[]{x.rdiv(y)}; + return new SDVariable[]{x.rdiv("out", y)}; } }, xAndY); INDArray assertionForDiv = Nd4j.valueArrayOf(4, 4.0); INDArray assertionForRDiv = Nd4j.valueArrayOf(4, 0.25); - assertEquals(assertionForDiv, sameDiff.getFunction("div").execAndEndResult()); - assertEquals(assertionForRDiv, sameDiff.getFunction("rdiv").execAndEndResult()); + assertEquals(assertionForDiv, sameDiff.getFunction("div").outputSingle(null, "out")); + assertEquals(assertionForRDiv, sameDiff.getFunction("rdiv").outputSingle(null, "out")); } @@ -445,12 +437,12 @@ public class SameDiffTests extends BaseNd4jTest { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { SDVariable x = sameDiff.var("x", inputs.get("x")); - return new SDVariable[]{sameDiff.math().neg(x)}; + return new SDVariable[]{sameDiff.math().neg("out", x)}; } }, xAndY); INDArray assertionForDiv = Nd4j.valueArrayOf(4, -1); - assertEquals(assertionForDiv, sameDiff.getFunction("neg").execAndEndResult()); + assertEquals(assertionForDiv, sameDiff.getFunction("neg").outputSingle(null, "out")); } @@ -471,7 +463,7 @@ public class SameDiffTests extends BaseNd4jTest { }, inputs); INDArray assertion = sumInput.sum(1); - INDArray out = sameDiff.getFunction("sum").exec(Collections.emptyMap(), Collections.singletonList("sum")) + INDArray out = sameDiff.getFunction("sum").output(Collections.emptyMap(), Collections.singletonList("sum")) .get("sum"); assertEquals(assertion, out); } @@ -484,7 +476,7 @@ public class SameDiffTests extends BaseNd4jTest { */ SameDiff sameDiff = SameDiff.create(); SDVariable sdVariable = sameDiff.var("one", Nd4j.scalar(1.0)); - assumeNotNull(sameDiff.getVariable(sdVariable.getVarName())); + assumeNotNull(sameDiff.getVariable(sdVariable.name())); } @@ -500,7 +492,7 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sameDiff = SameDiff.create(); SDVariable sdVariable = sameDiff.var("one", Nd4j.scalar(1.0)); SDVariable add = sdVariable.add(1.0); - assertEquals(sameDiff.getVariable(add.getVarName()), add); + assertEquals(sameDiff.getVariable(add.name()), add); } @@ -508,17 +500,8 @@ public class SameDiffTests extends BaseNd4jTest { public void testUpdateVariable() { SameDiff sameDiff = SameDiff.create(); SDVariable one = sameDiff.one("one", new long[]{1, 1}); - sameDiff.updateVariableName(one.getVarName(), "one-diff"); - assertEquals(one.getArr(), sameDiff.getVariable("one-diff").getArr()); - } - - - @Test(expected = IllegalStateException.class) - public void testPlaceHolderWithFullShape() { - val sd = SameDiff.create(); - val placeholder = sd.placeHolder("somevar", DataType.FLOAT, 2, 2); - assertTrue(sd.isPlaceHolder(placeholder.getVarName())); - sd.resolveVariablesWith(Collections.singletonMap(placeholder.getVarName(), Nd4j.linspace(1, 4, 4))); + one.rename("one-diff"); + assertEquals(one.eval(), sameDiff.getVariable("one-diff").eval()); } @@ -554,8 +537,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable left = sameDiff.var("arr", arr); SDVariable right = sameDiff.var("row", row); SDVariable test = left.add(right); - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - assertEquals(assertion, test.getArr()); + assertEquals(assertion, test.eval()); } @@ -604,7 +586,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(w, sd.getVariable("W")); sd.associateArrayWithVariable(b, sd.getVariable("b")); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); assertArrayEquals(new long[]{minibatch, nOut}, outArr.shape()); } @@ -644,7 +626,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(weightsArr, weights); sd.associateArrayWithVariable(biasArr, bias); - INDArray result = sd.execAndEndResult(); + INDArray result = avgMSE.eval(); } @@ -662,7 +644,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inArr = Nd4j.create(10, 9, 8); sd.associateArrayWithVariable(inArr, in); - INDArray out = sd.execAndEndResult(); //Exception here, dim0=-1 case only + INDArray out = mean2.eval(); long[] shape = out.shape(); assertArrayEquals(msg, new long[]{10}, shape); @@ -677,10 +659,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", new long[]{10, 9, 8}); SDVariable mean1 = sd.mean(in, 2); //[10,9] out SDVariable mean2 = sd.mean(mean1, 1); //[10] out - Map m = sd.output((Map)null, mean1.getVarName(), mean2.getVarName()); + Map m = sd.output((Map)null, mean1.name(), mean2.name()); - INDArray m1 = m.get(mean1.getVarName()); - INDArray m2 = m.get(mean2.getVarName()); + INDArray m1 = m.get(mean1.name()); + INDArray m2 = m.get(mean2.name()); assertArrayEquals(new long[]{10, 9}, m1.shape()); assertArrayEquals(new long[]{10}, m2.shape()); @@ -693,20 +675,20 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sd2 = SameDiff.create(); SDVariable in2 = sd2.var("in", new long[]{10, 9, 8}); SDVariable meanA = sd2.mean(in2, 0); //[9,8] out - sd2.exec(null, sd2.outputs()); - assertArrayEquals(new long[]{9, 8}, meanA.getShape()); + Map out = sd2.outputAll(null); + assertArrayEquals(new long[]{9, 8}, out.get(meanA.name()).shape()); SDVariable meanB = sd2.mean(meanA, 0); //[8] out Map m = sd2.outputAll(null); - assertArrayEquals(new long[]{8}, m.get(meanB.getVarName()).shape()); + assertArrayEquals(new long[]{8}, m.get(meanB.name()).shape()); - assertArrayEquals(meanA.getShape(), m.get(meanA.getVarName()).shape()); - assertArrayEquals(meanB.getShape(), m.get(meanB.getVarName()).shape()); + assertArrayEquals(new long[]{9, 8}, m.get(meanA.name()).shape()); + assertArrayEquals(new long[]{8}, m.get(meanB.name()).shape()); m = sd2.outputAll(null); - INDArray mA = m.get(meanA.getVarName()); - INDArray mB = m.get(meanB.getVarName()); + INDArray mA = m.get(meanA.name()); + INDArray mB = m.get(meanB.name()); assertArrayEquals(new long[]{9, 8}, mA.shape()); assertArrayEquals(new long[]{8}, mB.shape()); @@ -723,9 +705,9 @@ public class SameDiffTests extends BaseNd4jTest { val s = in2.add(5.0); Map map = sd.outputAll(null); - log.info("Result M: {}", map.get(m.getVarName())); - log.info("Result F: {}", map.get(f.getVarName())); - log.info("Result S: {}", map.get(s.getVarName())); + log.info("Result M: {}", map.get(m.name())); + log.info("Result F: {}", map.get(f.name())); + log.info("Result S: {}", map.get(s.name())); } @Test @@ -774,8 +756,8 @@ public class SameDiffTests extends BaseNd4jTest { val input2 = sd.var("input2", vector); val output = sd .mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build()); - output.eval(); - assertArrayEquals(new long[]{3, 1}, output.getShape()); + INDArray out = output.eval(); + assertArrayEquals(new long[]{3, 1}, out.shape()); } @Test @@ -807,10 +789,8 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sameDiff = SameDiff.create(); SDVariable twoByTwo = sameDiff.var("initial", Nd4j.linspace(1, 4, 4, DataType.FLOAT).reshape(2, 2)); SDVariable sum = sameDiff.sum(twoByTwo, Integer.MAX_VALUE); - sameDiff.execBackwards(Collections.emptyMap()); - SameDiff grad = sameDiff.getFunction("grad"); - SDVariable gradArr = sameDiff.grad(twoByTwo.getVarName()); - assertEquals(Nd4j.ones(DataType.FLOAT, 2, 2), gradArr.getArr()); + Map grads = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); + assertEquals(Nd4j.ones(DataType.FLOAT, 2, 2), grads.get(twoByTwo.name())); } @@ -830,7 +810,7 @@ public class SameDiffTests extends BaseNd4jTest { }, params); SameDiff logisticGraph = sameDiff.getFunction("rsubop"); - INDArray output = logisticGraph.exec(params, Collections.singletonList("rsub")).get("rsub"); + INDArray output = logisticGraph.output(params, Collections.singletonList("rsub")).get("rsub"); assertEquals(Nd4j.ones(4).muli(-1), output); } @@ -863,7 +843,7 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff logisticGraph = sameDiffOuter.getFunction("oneminuspredictions"); Map inputsSubset = new HashMap<>(); inputsSubset.put("y", inputs.get("y")); - INDArray output = logisticGraph.exec(inputsSubset, Collections.singletonList("rsub")).get("rsub"); + INDArray output = logisticGraph.output(inputsSubset, Collections.singletonList("rsub")).get("rsub"); INDArray assertion = Nd4j.create(new double[]{0, 0, 1, 0}, new int[]{4, 1}); assertEquals(assertion, output); @@ -921,7 +901,7 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sameDiff = SameDiff.create(); SDVariable twoByTwo = sameDiff.var("first", Nd4j.linspace(1, 4, 4).reshape('c', 2, 2)); SDVariable add = twoByTwo.add(1.0); - INDArray test = sameDiff.execAndEndResult(); + INDArray test = add.eval(); INDArray assertion = Nd4j.linspace(1, 4, 4).reshape('c', 2, 2).add(1.0); assertEquals(assertion, test); } @@ -934,8 +914,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sdVariable = sameDiff.var("ones", ones); SDVariable result = sdVariable.add(1.0); SDVariable total = sameDiff.sum(result, Integer.MAX_VALUE); - sameDiff.execAndEndResult(); - assertEquals(56, total.getArr().getDouble(0), 1e-1); + INDArray out = total.eval(); + assertEquals(56, out.getDouble(0), 1e-1); } @@ -963,9 +943,9 @@ public class SameDiffTests extends BaseNd4jTest { Map m = sd.outputAll(Collections.emptyMap()); - assertEquals(expMmul, m.get(mmul.getVarName())); - assertEquals(expZ, m.get(z.getVarName())); - assertEquals(expOut, m.get(out.getVarName())); + assertEquals(expMmul, m.get(mmul.name())); + assertEquals(expZ, m.get(z.name())); + assertEquals(expOut, m.get(out.name())); } @Test @@ -1099,7 +1079,7 @@ public class SameDiffTests extends BaseNd4jTest { 0.0, 1); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); assertArrayEquals(new long[]{1, 10}, outArr.shape()); } @@ -1121,10 +1101,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = sd.cnn().localResponseNormalization(sdInput, lrn); SDVariable sdOut = sd.math().tanh("out", out); - Map map = sd.output(Collections.emptyMap(), "out", out.getVarName()); + Map map = sd.output(Collections.emptyMap(), "out", out.name()); for (int i = 0; i < 4; i++) { - assertEquals(1, map.get(out.getVarName()).get(all(), NDArrayIndex.point(i), all(), all()).getInt(0)); + assertEquals(1, map.get(out.name()).get(all(), NDArrayIndex.point(i), all(), all()).getInt(0)); } } @@ -1147,8 +1127,8 @@ public class SameDiffTests extends BaseNd4jTest { Map m = sd.outputAll(null); - INDArray meanArray = m.get(mean.getVarName()); - INDArray varArray = m.get(variance.getVarName()); + INDArray meanArray = m.get(mean.name()); + INDArray varArray = m.get(variance.name()); assertEquals(meanArray.getDouble(0), 2.5, 1e-5); assertEquals(varArray.getDouble(0), 1.25, 1e-5); @@ -1176,8 +1156,8 @@ public class SameDiffTests extends BaseNd4jTest { Map m = sd.outputAll(null); - INDArray meanArray = m.get(normMean.getVarName()); - INDArray varArray = m.get(normVariance.getVarName()); + INDArray meanArray = m.get(normMean.name()); + INDArray varArray = m.get(normVariance.name()); assertEquals(meanArray.getDouble(0, 0), 1, 1e-5); assertEquals(meanArray.getDouble(0, 1), 2, 1e-5); @@ -1217,7 +1197,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = sd.cnn().depthWiseConv2d(in, dW, b, c); out = sd.math().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape); @@ -1233,11 +1213,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable mean = sd.mean("mean", v); - INDArray out = sd.execAndEndResult(); + INDArray out = mean.eval(); assertEquals(out, arr.mean(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); - INDArray dLdIn = sd.grad("in").getArr(); + Map m = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); + INDArray dLdIn = m.get("in"); //If L = mean(in) //then dL/dIn = 1/N @@ -1255,11 +1235,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable mean = sd.sum("sum", v); - INDArray out = sd.execAndEndResult(); + INDArray out = mean.eval(); assertEquals(out, arr.sum(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); - INDArray dLdIn = sd.grad("in").getArr(); + Map m = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); + INDArray dLdIn = m.get("in"); //If L = sum(in) //then dL/dIn = 1 @@ -1278,10 +1258,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable stdev = sd.standardDeviation("stdev", v, biasCorrected); - INDArray out = sd.execAndEndResult(); + INDArray out = stdev.eval(); assertEquals(out, arr.std(biasCorrected, Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); + Map g = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); INDArray dLdIn = sd.grad("in").getArr(); //If L = stdev(in) @@ -1308,11 +1288,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable var = sd.variance("var", v, biasCorrected); - INDArray out = sd.execAndEndResult(); + INDArray out = var.eval(); assertEquals(out, arr.var(biasCorrected, Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); - INDArray dLdIn = sd.grad("in").getArr(); + Map g = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); + INDArray dLdIn = g.get("in"); //If L = var(in) //then dL/dIn = 2/(N-1) * (in-mean) @@ -1336,17 +1316,17 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable min = sd.min("min", v); - INDArray out = sd.execAndEndResult(); + INDArray out = min.eval(); assertEquals(out, arr.min(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); + Map g = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); INDArray dLdIn = sd.grad("in").getArr(); //If L = min(in) //then dL/dIn = 1 if in_i == min(in) or 0 otherwise //Note that we don't have an "IsMin" op, so use IsMax(neg(in)) which is equivalent - INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.neg()))[0].castTo(Nd4j.defaultFloatingPointType()); + INDArray exp = Nd4j.exec(new IsMax(arr.neg()))[0].castTo(Nd4j.defaultFloatingPointType()); assertEquals(exp, dLdIn); } @@ -1361,16 +1341,16 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable min = sd.max("max", v); - INDArray out = sd.execAndEndResult(); + INDArray out = min.eval(); assertEquals(out, arr.max(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); INDArray dLdIn = sd.grad("in").getArr(); //If L = max(in) //then dL/dIn = 1 if in_i == max(in) or 0 otherwise - INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.dup()))[0].castTo(DataType.DOUBLE); + INDArray exp = Nd4j.exec(new IsMax(arr.dup()))[0].castTo(DataType.DOUBLE); assertEquals(exp, dLdIn); } @@ -1386,10 +1366,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable prod = sd.prod("prod", v); double p = arr.prodNumber().doubleValue(); - INDArray out = sd.execAndEndResult(); + INDArray out = prod.eval(); assertEquals(out, arr.prod(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); + Map g = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); INDArray dLdIn = sd.grad("in").getArr(); //If L = prod(in) @@ -1415,8 +1395,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray expOut = in.getArr().sub(label.getArr()); expOut.muli(expOut); - System.out.println("About to exec"); - INDArray out = sd.execAndEndResult(); //JVM crash + INDArray out = sqDiff.eval(); assertEquals(out, expOut); } @@ -1429,7 +1408,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", Nd4j.create(2, 3)); SDVariable expanded = sd.f().expandDims(in, i); - INDArray out = sd.execAndEndResult(); + INDArray out = expanded.eval(); switch (i) { case 0: assertArrayEquals(new long[]{1, 2, 3}, out.shape()); @@ -1452,11 +1431,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable var0 = sd.var("in", DataType.DOUBLE, new long[]{3, 4}); SDVariable out = sd.zerosLike("out", var0); - INDArray out1 = sd.execAndEndResult(); + INDArray out1 = out.eval(); assertEquals(Nd4j.zeros(3, 4), out1); sd.associateArrayWithVariable(Nd4j.create(3, 4), var0); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = out.eval(); assertEquals(Nd4j.zeros(DataType.DOUBLE, 3, 4), out2); } @@ -1466,11 +1445,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable var0 = sd.var("in", new long[]{3, 4}); SDVariable out = sd.onesLike("out", var0); - INDArray out1 = sd.execAndEndResult(); + INDArray out1 = out.eval(); assertEquals(Nd4j.ones(3, 4), out1); sd.associateArrayWithVariable(Nd4j.create(3, 4), var0); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = out.eval(); assertEquals(Nd4j.ones(3, 4), out2); } @@ -1482,12 +1461,12 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable ones = sd.onesLike("ones", var0); SDVariable out = sd.sum("oun", ones); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); assertEquals(Nd4j.scalar(12.0), outArr); - sd.execBackwards(Collections.emptyMap()); + Map m = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); - assertEquals(Nd4j.create(3, 4), sd.grad("in").getArr()); + assertEquals(Nd4j.create(3, 4), m.get("in")); } @@ -1498,7 +1477,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray a = Nd4j.rand(new long[]{3, 4, 5}); INDArray b = Nd4j.rand(new long[]{3, 4, 5}); - INDArray expOut = Nd4j.getExecutioner().exec(new ManhattanDistance(a, b, 0)); + INDArray expOut = Nd4j.exec(new ManhattanDistance(a, b, 0)); val expShape = new long[]{4, 5}; @@ -1526,7 +1505,7 @@ public class SameDiffTests extends BaseNd4jTest { double maxSum = max.sumNumber().doubleValue(); double jd = 1.0 - minSum / maxSum; - INDArray out = sd.execAndEndResult(); + INDArray out = jaccard.eval(); assertEquals(1, out.length()); assertEquals(jd, out.getDouble(0), 1e-6); @@ -1574,36 +1553,36 @@ public class SameDiffTests extends BaseNd4jTest { case 4: t = sd.gte(in1, in2); expOut = Nd4j.create(DataType.BOOL, ia.shape()); - Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut})); + Nd4j.exec(new GreaterThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut})); break; case 5: t = sd.lte(in1, in2); expOut = Nd4j.create(DataType.BOOL, ia.shape()); - Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut})); + Nd4j.exec(new LessThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut})); break; case 6: - ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); - ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); + ia = Nd4j.exec(new BernoulliDistribution(ia, 0.5)); + ib = Nd4j.exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().or(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); expOut = Transforms.or(ia, ib); break; case 7: t = sd.max(in1, in2); - expOut = Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]; + expOut = Nd4j.exec(new Max(ia, ib, ia.dup()))[0]; break; case 8: t = sd.min(in1, in2); - expOut = Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]; + expOut = Nd4j.exec(new Min(ia, ib, ia.dup()))[0]; break; case 9: - ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); - ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); + ia = Nd4j.exec(new BernoulliDistribution(ia, 0.5)); + ib = Nd4j.exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().and(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); expOut = Transforms.and(ia, ib); break; case 10: - ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); - ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); + ia = Nd4j.exec(new BernoulliDistribution(ia, 0.5)); + ib = Nd4j.exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().xor(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); expOut = Transforms.xor(ia, ib); break; @@ -1612,7 +1591,7 @@ public class SameDiffTests extends BaseNd4jTest { } log.info("Executing: " + i); - INDArray out = sd.execAndEndResult(); + INDArray out = t.eval(); assertEquals(expOut, out); } @@ -1640,22 +1619,22 @@ public class SameDiffTests extends BaseNd4jTest { switch (i) { case 0: t = sd.math().isNonDecreasing(in1); - Nd4j.getExecutioner().exec(new IsNonDecreasing(new INDArray[]{ia}, new INDArray[]{expOut})); + Nd4j.exec(new IsNonDecreasing(new INDArray[]{ia}, new INDArray[]{expOut})); break; case 1: t = sd.math().isStrictlyIncreasing(in1); - Nd4j.getExecutioner().exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut})); + Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut})); break; case 2: t = sd.isNumericTensor(in1); - Nd4j.getExecutioner().exec(new IsNumericTensor(new INDArray[]{ia}, new INDArray[]{expOut})); + Nd4j.exec(new IsNumericTensor(new INDArray[]{ia}, new INDArray[]{expOut})); break; default: throw new RuntimeException(); } log.info("Executing: " + i); - INDArray out = sd.execAndEndResult(); + INDArray out = t.eval(); assertEquals(expOut, out); } @@ -1674,7 +1653,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", inArr); SDVariable expand = sd.f().expandDims(in, i); - INDArray out = sd.execAndEndResult(); + INDArray out = expand.eval(); INDArray expOut; switch (i) { @@ -1715,7 +1694,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", inArr); SDVariable squeeze = sd.f().squeeze(in, i); - INDArray out = sd.execAndEndResult(); + INDArray out = squeeze.eval(); INDArray expOut; switch (i) { @@ -1754,7 +1733,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable expand = sd.expandDims(in, i); SDVariable squeeze = sd.squeeze(expand, i); - INDArray out = sd.execAndEndResult(); + INDArray out = squeeze.eval(); String msg = "expand/Squeeze=" + i + ", source=" + p.getSecond(); @@ -1782,7 +1761,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable squeeze = sd.squeeze(in, i); SDVariable expand = sd.expandDims(squeeze, i); - INDArray out = sd.execAndEndResult(); + INDArray out = expand.eval(); String msg = "expand/Squeeze=" + i + ", source=" + p.getSecond(); @@ -1801,8 +1780,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable labelsVar = sd.constant("labels", labels); SDVariable predictionsVar = sd.constant("predictions", pred); SDVariable weightsVar = sd.constant("weights", weights); - sd.math().confusionMatrix("cm", labelsVar, predictionsVar, numClasses, weightsVar); - INDArray out = sd.execAndEndResult(); + SDVariable cm = sd.math().confusionMatrix("cm", labelsVar, predictionsVar, numClasses, weightsVar); + INDArray out = cm.eval(); INDArray exp = Nd4j.create(new float[][]{{0, 0, 0, 0, 0}, {0, 0, 10, 0, 0}, {0, 0, 100, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 1000}}).castTo(DataType.INT); @@ -1821,7 +1800,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", inArr); SDVariable argmax = sd.argmax("argmax", in, dim); - INDArray out = sd.execAndEndResult(); + INDArray out = argmax.eval(); INDArray exp = Nd4j.argMax(inArr, dim); @@ -1841,7 +1820,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", inArr); SDVariable argmin = sd.argmin("argmin", in, dim); - INDArray out = sd.execAndEndResult(); + INDArray out = argmin.eval(); INDArray exp = Nd4j.argMax(inArr.neg(), dim); //argmin(x) == argmax(-x) @@ -2004,7 +1983,7 @@ public class SameDiffTests extends BaseNd4jTest { System.out.println(in); INDArray exp = Nd4j.pullRows(in, 1, new int[]{0, 1, 5}); //Along dimension 1 -> equiv to "indexes for axis 0" - INDArray act = sd.execAndEndResult(); + INDArray act = gather.eval(); assertEquals(exp, act); } @@ -2022,7 +2001,7 @@ public class SameDiffTests extends BaseNd4jTest { .addOutputs(out) .build(); - Nd4j.getExecutioner().exec(op); + Nd4j.exec(op); System.out.println(out); @@ -2058,10 +2037,9 @@ public class SameDiffTests extends BaseNd4jTest { INDArray expNaN = Nd4j.create(new boolean[]{false, false}); SDVariable isnan = sd.math().isNaN(in); - sd.exec(Collections.emptyMap(), sd.outputs()); - assertEquals(expFinite, finite.getArr()); - assertEquals(expInfinite, infinite.getArr()); - assertEquals(expNaN, isnan.getArr()); + assertEquals(expFinite, finite.eval()); + assertEquals(expInfinite, infinite.eval()); + assertEquals(expNaN, isnan.eval()); } @@ -2155,7 +2133,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable write0 = tensorArray.write(var2, 0, var1); SDVariable write1 = tensorArray.write(write0, 1, var2); SDVariable result = tensorArray.stack(write1); - sd.exec(null, result.getVarName()); + sd.output((Map)null, result.name()); assertEquals(Nd4j.pile(arr1, arr2), result.eval()); } @@ -2256,12 +2234,12 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sum = in.sum(1); INDArray exp = in.getArr().sum(1).reshape(3); - INDArray out = sd.execAndEndResult(); + INDArray out = sum.eval(); assertEquals(exp, out); //Now, replace with minibatch 5: in.setArray(Nd4j.linspace(1, 20, 20).reshape(5, 4)); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = sum.eval(); assertArrayEquals(new long[]{5}, out2.shape()); exp = in.getArr().sum(1).reshape(5); @@ -2276,12 +2254,12 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sum = in.argmax(1); INDArray exp = in.getArr().argMax(1).reshape(3); - INDArray out = sd.execAndEndResult(); + INDArray out = sum.eval(); assertEquals(exp, out); //Now, replace with minibatch 5: in.setArray(Nd4j.linspace(1, 20, 20).reshape(5, 4)); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = sum.eval(); assertArrayEquals(new long[]{5}, out2.shape()); exp = in.getArr().argMax(1).reshape(5); @@ -2300,12 +2278,11 @@ public class SameDiffTests extends BaseNd4jTest { gradMap.put("out", externalGrad); ExternalErrorsFunction fn = sd.f().externalErrors(out); - sd.execAndEndResult(); Map m = new HashMap<>(); m.put("out-grad", externalGrad); - sd.execBackwards(m); + Map grads = sd.calculateGradients(m, sd.getVariables().keySet()); - INDArray gradVar = var.getGradient().getArr(); + INDArray gradVar = grads.get(var.name()); assertEquals(externalGrad.mul(0.5), gradVar); @@ -2313,7 +2290,7 @@ public class SameDiffTests extends BaseNd4jTest { externalGrad = Nd4j.linspace(1, 12, 12).reshape(3, 4).muli(10); m.put("out-grad", externalGrad); - sd.execBackwards(m); + grads = sd.calculateGradients(m, sd.getVariables().keySet()); gradVar = var.getGradient().getArr(); @@ -2332,26 +2309,26 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = sd.mmul(in, w); SDVariable loss = out.std("out", true); - INDArray outArr = sd.execAndEndResult().dup(); + INDArray outArr = loss.eval(); // sd.execBackwards(Collections.emptyMap()); - Map grads = sd.calculateGradients(null, in.getVarName(), w.getVarName(), out.getVarName()); + Map grads = sd.calculateGradients(null, in.name(), w.name(), out.name()); Map origGrad = new HashMap<>(); - origGrad.put("in", grads.get(in.getVarName()).dup()); - origGrad.put("w", grads.get(w.getVarName()).dup()); - origGrad.put("out", grads.get(out.getVarName()).dup()); + origGrad.put("in", grads.get(in.name()).dup()); + origGrad.put("w", grads.get(w.name()).dup()); + origGrad.put("out", grads.get(out.name()).dup()); in.getArr().assign(Nd4j.rand(in.getArr().shape())); - INDArray outArr2 = sd.execAndEndResult(); + INDArray outArr2 = loss.eval(); // sd.execBackwards(Collections.emptyMap()); - grads = sd.calculateGradients(null, in.getVarName(), w.getVarName(), out.getVarName()); + grads = sd.calculateGradients(null, in.name(), w.name(), out.name()); assertNotEquals(outArr, outArr2); //Ensure gradients are also changed: - assertNotEquals(origGrad.get("in"), grads.get(in.getVarName())); - assertNotEquals(origGrad.get("w"), grads.get(w.getVarName())); - assertNotEquals(origGrad.get("out"), grads.get(out.getVarName())); + assertNotEquals(origGrad.get("in"), grads.get(in.name())); + assertNotEquals(origGrad.get("w"), grads.get(w.name())); + assertNotEquals(origGrad.get("out"), grads.get(out.name())); } @Test @@ -2361,25 +2338,25 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = in.mul(2.0); SDVariable loss = out.std("out", true); - INDArray outArr = sd.execAndEndResult().dup(); - Map grads = sd.calculateGradients(null, in.getVarName(), out.getVarName()); + INDArray outArr = loss.eval(); + Map grads = sd.calculateGradients(null, in.name(), out.name()); Map origGrad = new HashMap<>(); - origGrad.put("in", grads.get(in.getVarName()).dup()); - origGrad.put("out", grads.get(out.getVarName()).dup()); + origGrad.put("in", grads.get(in.name()).dup()); + origGrad.put("out", grads.get(out.name()).dup()); double stdBefore = in.getArr().stdNumber().doubleValue(); in.getArr().assign(Nd4j.rand(in.getArr().shape())); double stdAfter = in.getArr().stdNumber().doubleValue(); System.out.println("Before vs. after: " + stdBefore + ", " + stdAfter); - INDArray outArr2 = sd.execAndEndResult(); - grads = sd.calculateGradients(null, in.getVarName(), out.getVarName()); + INDArray outArr2 = loss.eval(); + grads = sd.calculateGradients(null, in.name(), out.name()); assertNotEquals(outArr, outArr2); //Ensure gradients are also changed: - assertNotEquals(origGrad.get("in"), grads.get(in.getVarName())); - assertNotEquals(origGrad.get("out"), grads.get(out.getVarName())); + assertNotEquals(origGrad.get("in"), grads.get(in.name())); + assertNotEquals(origGrad.get("out"), grads.get(out.name())); } @Test @@ -2416,8 +2393,8 @@ public class SameDiffTests extends BaseNd4jTest { grad = Nd4j.linspace(1, 8, 8).reshape(2, 4); phMap.put(fn.getGradPlaceholderName(), grad); - sd.execBackwards(phMap); - INDArray inGrad = in.getGradient().getArr(); + Map grads = sd.calculateGradients(phMap, sd.getVariables().keySet()); + INDArray inGrad = grads.get(in.name()); assertArrayEquals(new long[]{2, 5}, inGrad.shape()); } @@ -2473,7 +2450,7 @@ public class SameDiffTests extends BaseNd4jTest { Map placeholders = new HashMap<>(); placeholders.put("x", x); placeholders.put("y", y); - Map grads = sd.calculateGradients(placeholders, xSd.getVarName(), ySd.getVarName()); + Map grads = sd.calculateGradients(placeholders, xSd.name(), ySd.name()); INDArray xGradientEnforced = grads.get("x"); assertNotNull(xGradientEnforced); } @@ -2537,6 +2514,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inArr = Nd4j.rand(DataType.FLOAT, 1, 3); in.setArray(inArr); INDArray inArr2 = Nd4j.rand(DataType.FLOAT, 3, 4); + in2.setArray(inArr2); TrainingConfig c = TrainingConfig.builder() .updater(new Adam(0.1)) @@ -2732,12 +2710,12 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sd = SameDiff.create(); SDVariable linspace = sd.linspace("at", DataType.DOUBLE, 1, 15, 15); SDVariable a = sd.reshape("a", linspace, 3, 5); - SDVariable b = sd.one("b", DataType.DOUBLE, 3, 5); + SDVariable b = sd.var("b", Nd4j.ones(DataType.DOUBLE, 3, 5)); SDVariable out = a.mul(b); + out.markAsLoss(); out.eval(); - sd.execBackwards(null, "a"); out.eval(); sd.grad("a").eval(); @@ -2752,12 +2730,12 @@ public class SameDiffTests extends BaseNd4jTest { public void testNonScalarOutput2() { SameDiff sd = SameDiff.create(); SDVariable a = sd.reshape("a", sd.linspace("at", DataType.DOUBLE, 1, 15, 15), 3, 5); - SDVariable b = sd.one("b", DataType.DOUBLE, 3, 5); + SDVariable b = sd.var("b", Nd4j.ones(DataType.DOUBLE, 3, 5)); SDVariable out = a.mul(b).mean(1); + out.markAsLoss(); out.eval(); - sd.execBackwards(null, Lists.asList("a", new String[]{})); //System.out.println(out.eval()); INDArray actGrad = sd.grad("a").eval(); @@ -2772,15 +2750,16 @@ public class SameDiffTests extends BaseNd4jTest { public void testNonScalarOutput3() { SameDiff sd = SameDiff.create(); SDVariable a = sd.reshape("a", sd.linspace("at", DataType.DOUBLE, 1, 15, 15), 3, 5); - SDVariable b = sd.one("b", DataType.DOUBLE, 3, 5);//.add(3); + SDVariable b = sd.var("b", Nd4j.ones(DataType.DOUBLE, 3, 5));//.add(3); SDVariable out = a.mul(b).mean(0, 1); + out.markAsLoss(); out.eval(); - sd.execBackwards(null, "a"); + Map g = sd.calculateGradients(null, "a"); //System.out.println(out.eval()); - INDArray gradAct = sd.grad("a").eval(); + INDArray gradAct = g.get("a"); INDArray expGrad = Nd4j.valueArrayOf(new long[]{3, 5}, 1.0 / 12, DataType.DOUBLE); String err = OpValidation.validate(new TestCase(sd).gradientCheck(true)); @@ -2798,7 +2777,7 @@ public class SameDiffTests extends BaseNd4jTest { Map m = new HashMap<>(); m.put("b", Nd4j.rand(DataType.DOUBLE, 4, 5)); - sd.execBackwards(m, "a", "b"); + Map g = sd.calculateGradients(m, "a", "b"); b.setArray(m.get("b")); @@ -2820,7 +2799,7 @@ public class SameDiffTests extends BaseNd4jTest { final SDVariable out = a.mmul(b).add(c.mmul(d)).sum(); out.markAsLoss(); - sd.execBackwards(null); + Map g = sd.calculateGradients(null, sd.getVariables().keySet()); } @Test @@ -2832,7 +2811,7 @@ public class SameDiffTests extends BaseNd4jTest { a.add(b.add(c)).sum().markAsLoss(); - sd.execBackwards(Collections.singletonMap("c", Nd4j.rand(4, 4))); + sd.calculateGradients(Collections.singletonMap("c", Nd4j.rand(4, 4)), sd.getVariables().keySet()); assertNotNull(sd.grad("a")); assertNull(sd.grad("b")); assertNull(sd.grad("c")); @@ -2908,16 +2887,16 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4, 5)); SDVariable v3 = v1.mmul("oldName", v2); - INDArray out = sd.execSingle(null, "oldName"); + INDArray out = sd.outputSingle(null, "oldName"); SDVariable renamed = v3.rename("newName"); assertTrue(v3 == renamed); - assertEquals("newName", renamed.getVarName()); + assertEquals("newName", renamed.name()); assertNull(sd.getVariable("oldName")); assertNotNull(sd.getVariable("newName")); - INDArray out2 = sd.execSingle(null, "newName"); + INDArray out2 = sd.outputSingle(null, "newName"); assertEquals(out, out2); } @@ -2931,7 +2910,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v3 = v1.mmul("oldName", v2); SDVariable v4 = v3.std("out", false); - INDArray out = sd.execSingle(Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4)), "out"); + INDArray out = sd.outputSingle(Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4)), "out"); sd.setTrainingConfig(TrainingConfig.builder() .updater(new Adam(1e-3)) @@ -3029,7 +3008,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn); - Map m = sd.exec(Collections.singletonMap("in", inputArr), "softmax"); + Map m = sd.output(Collections.singletonMap("in", inputArr), "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); @@ -3039,7 +3018,7 @@ public class SameDiffTests extends BaseNd4jTest { Map allPh = new HashMap<>(); allPh.put("in", inputArr); allPh.put("label", labelUnused); - m = sd.exec(allPh, "softmax"); + m = sd.output(allPh, "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); INDArray out2 = m.get("softmax"); @@ -3069,7 +3048,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn); - Map m = sd.exec(Collections.singletonMap("in", inputArr), "softmax"); + Map m = sd.output(Collections.singletonMap("in", inputArr), "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); @@ -3080,7 +3059,7 @@ public class SameDiffTests extends BaseNd4jTest { allPh.put("in", inputArr); allPh.put("label", labelUnused); allPh.put("in2", Nd4j.scalar(1.0f)); - m = sd.exec(allPh, "softmax"); + m = sd.output(allPh, "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); INDArray out2 = m.get("softmax"); @@ -3104,7 +3083,7 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(DataType.FLOAT, tanh.dataType()); assertEquals(DataType.FLOAT, stdev.dataType()); - Map out = sd.exec(null, "x", "y", "z", "tanh", "stdev"); + Map out = sd.output((Map)null, "x", "y", "z", "tanh", "stdev"); for (Map.Entry e : out.entrySet()) { assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType()); } @@ -3123,7 +3102,7 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(DataType.DOUBLE, tanh.dataType()); assertEquals(DataType.DOUBLE, stdev.dataType()); - out = sd.exec(null, "x", "y", "z", "tanh", "stdev"); + out = sd.output((Map)null, "x", "y", "z", "tanh", "stdev"); for (Map.Entry e : out.entrySet()) { assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType()); } @@ -3152,7 +3131,7 @@ public class SameDiffTests extends BaseNd4jTest { Map ph = Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4)); - Map out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r"); + Map out = sd.output(ph, "x", "y", "xD", "yD", "a", "r"); for (Map.Entry e : out.entrySet()) { if (e.getKey().equals("x") || e.getKey().equals("y")) { assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType()); @@ -3175,7 +3154,7 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(DataType.DOUBLE, add.dataType()); assertEquals(DataType.DOUBLE, relu.dataType()); - out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r"); + out = sd.output(ph, "x", "y", "xD", "yD", "a", "r"); for (Map.Entry e : out.entrySet()) { assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType()); } @@ -3210,9 +3189,9 @@ public class SameDiffTests extends BaseNd4jTest { assertNotNull(w.gradient()); assertNotNull(b.gradient()); - sd.execBackwards(Collections.singletonMap("in", in)); - assertNotNull(ph.gradient().getArr()); - assertNotNull(w.gradient().getArr()); + Map m = sd.calculateGradients(Collections.singletonMap("in", in), ph.name(), w.name()); + assertNotNull(m.get(ph.name())); + assertNotNull(m.get(w.name())); } else { sd.createGradFunction(); assertNull(ph.gradient()); @@ -3287,7 +3266,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray out = sum[1].eval(); assertEquals(15, out.getInt(0)); - String outName = sum[1].getVarName(); + String outName = sum[1].name(); SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); @@ -3313,7 +3292,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray out = sum[1].eval(); assertEquals(35, out.getInt(0)); - String outName = sum[1].getVarName(); + String outName = sum[1].name(); SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); @@ -3339,7 +3318,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray out = sum[1].eval(); assertEquals(115, out.getInt(0)); - String outName = sum[1].getVarName(); + String outName = sum[1].name(); SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); @@ -3423,4 +3402,28 @@ public class SameDiffTests extends BaseNd4jTest { String err = OpValidation.validate(tc); assertNull(err); } + + @Test + public void testSameDiffSeedReproducibilityVarInit() { + + SameDiff sd0 = SameDiff.create(); + SameDiff sd1 = SameDiff.create(); + Nd4j.getRandom().setSeed(12345); + SDVariable rand0 = sd0.var("random", new UniformInitScheme('c', 3), DataType.FLOAT, 3, 1); + + Nd4j.getRandom().setSeed(12345); + SDVariable rand1 = sd1.var("random", new UniformInitScheme('c', 3), DataType.FLOAT, 3, 1); + + + Nd4j.getRandom().setSeed(0); + System.out.println(rand0.eval()); + + Nd4j.getRandom().setSeed(0); + System.out.println(rand1.eval()); + + INDArray a0 = rand0.eval(); + Nd4j.getRandom().setSeed(0); + INDArray a1 = rand1.eval(); + assertEquals(a0, a1); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java index ffb7c319d..157fc5b46 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java @@ -109,6 +109,8 @@ public class FileReadWriteTests extends BaseNd4jTest { for (int i = 0; i < s.outputsLength(); i++) { outputs.add(s.outputs(i)); } + if(outputs.isEmpty()) + outputs = null; assertEquals(sd.outputs(), outputs); //Check variables diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java index 417652dcc..44b465fd3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java @@ -63,7 +63,7 @@ public class UIListenerTest { Map m = new HashMap<>(); iter.reset(); m.put("in", iter.next().getFeatures()); - INDArray out = sd.execSingle(m, "softmax"); + INDArray out = sd.outputSingle(m, "softmax"); assertNotNull(out); assertArrayEquals(new long[]{150, 3}, out.shape()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java index ddf4775a8..d320ad6e3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java @@ -78,7 +78,7 @@ public class ExecutionTests extends BaseNd4jTest { val tg = TFGraphMapper.importGraphTxt(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream(), null, null); System.out.println(tg.summary()); - Map result_0 = tg.exec(Collections.emptyMap(), tg.outputs()); + Map result_0 = tg.outputAll(null); val exp_0 = Nd4j.create(DataType.FLOAT, 3).assign(3.0); assertEquals(exp_0, result_0.get("Sum")); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java index c70dfa436..118897a27 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java @@ -174,7 +174,7 @@ public class BERTGraphTest extends BaseNd4jTest { //Find pre-dropout input variable: SDVariable newOut = null; for(SDVariable v : inputs){ - if(v.getVarName().endsWith("/BiasAdd") || v.getVarName().endsWith("/Softmax") || v.getVarName().endsWith("/add_1") || v.getVarName().endsWith("/Tanh")){ + if(v.name().endsWith("/BiasAdd") || v.name().endsWith("/Softmax") || v.name().endsWith("/add_1") || v.name().endsWith("/Tanh")){ newOut = v; break; } @@ -249,7 +249,7 @@ public class BERTGraphTest extends BaseNd4jTest { placeholderValues.put("IteratorGetNext:1", mask); placeholderValues.put("IteratorGetNext:4", segmentIdxs); - Map out = sd.exec(placeholderValues, "loss/Softmax"); + Map out = sd.output(placeholderValues, "loss/Softmax"); INDArray softmax = out.get("loss/Softmax"); // System.out.println("OUTPUT - Softmax"); // System.out.println(softmax); @@ -335,8 +335,8 @@ public class BERTGraphTest extends BaseNd4jTest { //For training, convert weights and biases from constants to variables: for(SDVariable v : sd.variables()){ - if(v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar() && !floatConstants.contains(v.getVarName())){ //Skip scalars - trainable params - log.info("Converting to variable: {} - dtype: {} - shape: {}", v.getVarName(), v.dataType(), Arrays.toString(v.getArr().shape())); + if(v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar() && !floatConstants.contains(v.name())){ //Skip scalars - trainable params + log.info("Converting to variable: {} - dtype: {} - shape: {}", v.name(), v.dataType(), Arrays.toString(v.getArr().shape())); v.convertToVariable(); } } @@ -393,14 +393,14 @@ public class BERTGraphTest extends BaseNd4jTest { placeholderValues.put("IteratorGetNext:4", segmentIdxs); placeholderValues.put("label", labelArr); - INDArray lossArr = sd.exec(placeholderValues, "loss").get("loss"); + INDArray lossArr = sd.output(placeholderValues, "loss").get("loss"); assertTrue(lossArr.isScalar()); double scoreBefore = lossArr.getDouble(0); for( int i=0; i<5; i++ ){ sd.fit(mds); } - lossArr = sd.exec(placeholderValues, "loss").get("loss"); + lossArr = sd.output(placeholderValues, "loss").get("loss"); assertTrue(lossArr.isScalar()); double scoreAfter = lossArr.getDouble(0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java index 43a9b2911..77d3c759e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java @@ -105,11 +105,9 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { //Perform inference List inputs = sd.inputs(); assertEquals(1, inputs.size()); - List outputs = sd.outputs(); - assertEquals(1, outputs.size()); - String out = outputs.get(0); - Map m = sd.exec(Collections.singletonMap(inputs.get(0), img), out); + String out = "MobilenetV1/Predictions/Softmax"; + Map m = sd.output(Collections.singletonMap(inputs.get(0), img), out); INDArray outArr = m.get(out); @@ -167,7 +165,7 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { assertEquals(1, inputs.size()); String out = "softmax_tensor"; - Map m = sd.exec(Collections.singletonMap(inputs.get(0), img), out); + Map m = sd.output(Collections.singletonMap(inputs.get(0), img), out); INDArray outArr = m.get(out); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index 745f5f8fd..22b8b4492 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -106,7 +106,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //g.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/mnist.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.VARIABLE_SPACE).build()); - g.execAndEndResult(); + g.outputAll(null); } @@ -129,7 +129,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { val graph = TFGraphMapper.importGraph(new ClassPathResource("/tf_graphs/argmax.pb.txt").getInputStream()); log.info(graph.asFlatPrint()); - val result = graph.execAndEndResult(); + val result = graph.outputAll(null).get(graph.outputs().get(0)); val exp = Nd4j.createFromArray(new long[]{2, 2, 2}); @@ -222,7 +222,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { graph.var("Placeholder", p0); graph.var("Placeholder_1", p1); - val res = graph.execAndEndResult(); + val res = graph.outputAll(null).get(graph.outputs().get(0)); @@ -341,7 +341,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { val constIn = tg.getVariable("StridedSlice/input"); assertNotNull(constIn); - val arr = tg.getArrForVarName(constIn.getVarName()); + val arr = tg.getArrForVarName(constIn.name()); assertEquals(139.5, arr.sumNumber().doubleValue(), 1e-5); @@ -651,9 +651,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { INDArray input = Nd4j.linspace(1,40,40, DataType.FLOAT).reshape(10,4); INDArray expectedOutput = Nd4j.linspace(1,40,40, DataType.FLOAT).reshape(10,4).addRowVector(Nd4j.linspace(1,4,4, DataType.FLOAT)); - INDArray actual = graph.execSingle(Collections.singletonMap("input",input), graph.outputs().get(0)); + INDArray actual = graph.outputSingle(Collections.singletonMap("input",input), graph.outputs().get(0)); assertEquals(input,graph.getVariable("input").getArr()); - assertArrayEquals(input.shape(),graph.getShapeForVarName(graph.getVariable("input").getVarName())); assertEquals(expectedOutput,actual); } @@ -665,13 +664,13 @@ public class TensorFlowImportTest extends BaseNd4jTest { val variables = new HashMap(); for (val var : tg.variables()) { - variables.put(var.getVarName(), var); + variables.put(var.name(), var); } val functions = new HashMap(); for (val func: tg.ops()) { val ownName = func.getOwnName(); - String outName = func.outputVariables()[0].getVarName(); + String outName = func.outputVariables()[0].name(); assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName)); assertEquals(ownName, outName); @@ -704,7 +703,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb")); //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(1); assertNotNull(array); assertEquals(exp, array); @@ -723,7 +722,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(1); assertNotNull(array); assertEquals(exp, array); @@ -741,7 +740,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); /* - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(2); assertNotNull(array); assertEquals(exp, array);*/ @@ -759,7 +758,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(4); assertNotNull(array); assertEquals(exp, array); @@ -780,7 +779,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + INDArray array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(-1); assertNotNull(array); assertEquals(exp, array); @@ -800,7 +799,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(-3); assertNotNull(array); assertEquals(exp, array); @@ -822,7 +821,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + Map m = tg.outputAll(null); + val array = m.get(tg.outputs().get(0)); //val array = tg.getVariable("output").getArr(); val exp = Nd4j.create(2, 2).assign(15.0); assertNotNull(array); @@ -968,7 +968,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { assertNotNull(tg); val input_matrix = Nd4j.ones(3, 2); - val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); + val array = tg.outputSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); val exp = Nd4j.create(new float[] {1, 1, 2, 2, 3, 3}, new int[]{3, 2}); @@ -982,7 +982,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { val input_matrix = Nd4j.ones(3, 2); - val array = tg.exec(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); + val array = tg.output(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)).get(tg.outputs().get(0)); val exp = Nd4j.create(new float[] {2, 2}, new int[]{2}); @@ -997,7 +997,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_unstack.pb.txt").getInputStream()); assertNotNull(tg); - val array = tg.execSingle(Collections.emptyMap(), tg.outputs().get(0)); + val array = tg.outputSingle(Collections.emptyMap(), tg.outputs().get(0)); val exp = Nd4j.create(new float[] {5, 6, 7, 8}, new int[]{4}); @@ -1011,7 +1011,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { val input_matrix = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2); log.info("Graph: {}", tg.asFlatPrint()); - val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); + val array = tg.outputSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); val exp = Nd4j.create(new float[] {3,6, 9,12, 15,18, 21,24, 27,30}, new int[]{5, 2}); @@ -1023,7 +1023,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/losses/log_loss_rank2_axis1_SUM_OVER_BATCH_SIZE/frozen_model.pb").getInputStream()); - tg.execAndEndResult(); + tg.outputAll(null); } @Test @@ -1040,7 +1040,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { for (int e = 0; e < 1000; e++){ val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/reduce_any/rank0/frozen_model.pb").getInputStream()); - Map result = tg.exec(Collections.emptyMap(), tg.outputs()); + Map result = tg.output(Collections.emptyMap(), tg.outputs()); assertNotNull(result); assertTrue(result.size() > 0); @@ -1052,7 +1052,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/transforms/logicalxor_3,4_3,4/frozen_model.pb").getInputStream()); - tg.execAndEndResult(); + tg.outputAll(null); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java index 5b4c84b4a..06c85b289 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java @@ -118,7 +118,7 @@ public class ImportModelDebugger { List outputs = sd.outputs(); - sd.exec(ph, outputs); + sd.output(ph, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java new file mode 100644 index 000000000..c1e5a8704 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java @@ -0,0 +1,93 @@ +package org.nd4j.linalg.convolution; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.ops.transforms.Transforms; +import org.nd4j.resources.Resources; + +import java.io.File; +import java.util.*; + +import static org.junit.Assert.*; + +public class DeconvTests extends BaseNd4jTest { + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + public DeconvTests(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering() { + return 'c'; + } + + @Test + public void compareKeras() throws Exception { + File f = testDir.newFolder(); + Resources.copyDirectory("keras/deconv", f); + + File[] files = f.listFiles(); + + Set tests = new HashSet<>(); + for(File file : files){ + String n = file.getName(); + if(!n.startsWith("mb")) + continue; + + int idx = n.lastIndexOf('_'); + String name = n.substring(0, idx); + tests.add(name); + } + + List l = new ArrayList<>(tests); + Collections.sort(l); + assertFalse(l.isEmpty()); + + for(String s : l){ + String s2 = s.replaceAll("[a-zA-Z]", ""); + String[] nums = s2.split("_"); + int mb = Integer.parseInt(nums[0]); + int k = Integer.parseInt(nums[1]); + int size = Integer.parseInt(nums[2]); + int stride = Integer.parseInt(nums[3]); + boolean same = s.contains("same"); + int d = Integer.parseInt(nums[5]); + boolean nchw = s.contains("nchw"); + + INDArray w = Nd4j.readNpy(new File(f, s + "_W.npy")); + INDArray b = Nd4j.readNpy(new File(f, s + "_b.npy")); + INDArray in = Nd4j.readNpy(new File(f, s + "_in.npy")).castTo(DataType.FLOAT); + INDArray expOut = Nd4j.readNpy(new File(f, s + "_out.npy")); + + CustomOp op = DynamicCustomOp.builder("deconv2d") + .addInputs(in, w, b) + .addIntegerArguments( + k, k, + stride, stride, + 0, 0, //padding + d, d, + same ? 1 : 0, + nchw ? 0 : 1) + .callInplace(false) + .build(); + INDArray out = Nd4j.create(op.calculateOutputShape().get(0)); + out.assign(Double.NaN); + op.addOutputArgument(out); + Nd4j.exec(op); + + boolean eq = expOut.equalsWithEps(out, 1e-4); + assertTrue(eq); + } + } +} diff --git a/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java b/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java index bcdfba202..ec4739b86 100644 --- a/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java +++ b/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java @@ -21,8 +21,6 @@ import org.nd4j.config.ND4JEnvironmentVars; import org.nd4j.config.ND4JSystemProperties; import org.nd4j.context.Nd4jContext; import org.nd4j.linalg.io.Resource; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.File; import java.io.IOException; @@ -152,6 +150,9 @@ public abstract class Nd4jBackend { */ public static Nd4jBackend load() throws NoAvailableBackendException { + String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true"); + boolean logInit = Boolean.parseBoolean(logInitProperty); + List backends = new ArrayList<>(1); ServiceLoader loader = ServiceLoader.load(Nd4jBackend.class); try { @@ -183,7 +184,9 @@ public abstract class Nd4jBackend { error = e.getMessage(); } if (!available) { - log.warn("Skipped [{}] backend (unavailable): {}", backend.getClass().getSimpleName(), error); + if(logInit) { + log.warn("Skipped [{}] backend (unavailable): {}", backend.getClass().getSimpleName(), error); + } continue; } @@ -193,7 +196,9 @@ public abstract class Nd4jBackend { e.printStackTrace(); } - log.info("Loaded [{}] backend", backend.getClass().getSimpleName()); + if(logInit) { + log.info("Loaded [{}] backend", backend.getClass().getSimpleName()); + } return backend; } @@ -273,6 +278,8 @@ public abstract class Nd4jBackend { return getClass().getName(); } + public abstract void logBackendInit(); + @SuppressWarnings("serial") public static class NoAvailableBackendException extends Exception { diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java index 37dbfdb05..bf1c0ebc1 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java @@ -200,7 +200,7 @@ public class SameDiffServlet implements ModelServingServlet { map.put(n, mds.getFeatures(cnt++)); } - val output = sdModel.exec(map, orderedOutputNodes); + val output = sdModel.output(map, orderedOutputNodes); val arrays = new INDArray[output.size()]; // now we need to get ordered output arrays, as specified in server constructor