From d31197db5fce781240f3db4c53caa1f86afd92c7 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 27 Aug 2019 18:02:41 -0700 Subject: [PATCH] Remove resolvePropertiesFromSameDiffBeforeExecution() (#172) * remove unneeded resolveProperties methods Signed-off-by: Ryan Nett * final fixes, make final to prevent more from being added Signed-off-by: Ryan Nett * gather fix Signed-off-by: Ryan Nett * deprecate DifferentialFunction resolveProps Signed-off-by: Ryan Nett --- .../functions/DifferentialFunction.java | 5 +++- .../ops/impl/layers/convolution/Conv3D.java | 8 ------ .../api/ops/impl/shape/ConfusionMatrix.java | 5 ---- .../linalg/api/ops/impl/shape/Gather.java | 25 ------------------- .../linalg/api/ops/impl/shape/Repeat.java | 9 ------- .../linalg/api/ops/impl/shape/Squeeze.java | 9 +------ .../linalg/api/ops/impl/shape/Transpose.java | 5 ---- .../impl/transforms/HistogramFixedWidth.java | 10 -------- .../linalg/api/ops/impl/transforms/Pad.java | 11 -------- .../segment/UnsortedSegmentMax.java | 8 ------ .../segment/UnsortedSegmentMean.java | 8 ------ .../segment/UnsortedSegmentMin.java | 8 ------ .../segment/UnsortedSegmentProd.java | 8 ------ .../segment/UnsortedSegmentSqrtN.java | 8 ------ .../segment/UnsortedSegmentSum.java | 8 ------ 15 files changed, 5 insertions(+), 130 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 71bbd26ee..34240516f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -547,8 +547,11 @@ public abstract class DifferentialFunction { /** * Resolve properties and arguments right before execution of * this operation. + * + * @deprecated Will be removed in the future. Ops should support array arguments. Should not bs used or overridden. */ - public void resolvePropertiesFromSameDiffBeforeExecution() { + @Deprecated + public final void resolvePropertiesFromSameDiffBeforeExecution() { val properties = sameDiff.propertiesToResolveForFunction(this); val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this); val currentFields = this.propertiesForFunction(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java index d3fe330fd..810974103 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java @@ -268,14 +268,6 @@ public class Conv3D extends DynamicCustomOp { return ret; } - - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if (numIArguments() < 1) { - addArgs(); - } - } - @Override public boolean isConfigProperties() { return true; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java index 231cf5783..6275ce210 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java @@ -66,11 +66,6 @@ public class ConfusionMatrix extends DynamicCustomOp { //Looks like this is implemented in practice using a large collection of discrete ops - not single TF import op? } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public String opName() { return "confusion_matrix"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java index 31718d337..5613cc85f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java @@ -74,7 +74,6 @@ public class Gather extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - } @Override @@ -82,30 +81,6 @@ public class Gather extends DynamicCustomOp { OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); } - - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { -// super.resolvePropertiesFromSameDiffBeforeExecution(); - if (indices != null && numInputArguments() < 2) { - if (numInputArguments() == 0) { - INDArray a = Nd4j.create(indices, new long[]{indices.length}, new long[]{1}, 'c', DataType.INT); - if (indices.length > 1) - a = a.reshape(indices.length); - else - a = a.reshape(new int[]{}); - - addInputArgument(args()[0].getArr(), a); - } else if (numInputArguments() == 1) { - addInputArgument(Nd4j.create(indices, new long[]{indices.length}, new long[]{1}, 'c', DataType.INT)); - } - - } - - if (numIArguments() < 1) { - addIArgument(jaxis); - } - } - @Override public Map> mappingsForFunction() { Map> ret = new HashMap<>(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java index 02f8f9445..af8940bf4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java @@ -110,15 +110,6 @@ public class Repeat extends DynamicCustomOp { super.initFromOnnx(node, initWith, attributesForNode, graph); } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if (numOutputArguments() < getDescriptor().getNumOutputs()) { - for (val output : outputVariables()) { - addOutputArgument(output.getArr()); - } - } - } - @Override public String onnxName() { return "Repeat"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java index bfd8f58ec..816654d13 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java @@ -41,6 +41,7 @@ public class Squeeze extends DynamicCustomOp { public Squeeze(SameDiff sameDiff, SDVariable arg, int[] squeezeDims) { super(null, sameDiff, new SDVariable[]{arg}); this.squeezeDims = squeezeDims; + addIArgument(squeezeDims); } @Override @@ -53,14 +54,6 @@ public class Squeeze extends DynamicCustomOp { addIArgument(squeezeDims); } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - super.resolvePropertiesFromSameDiffBeforeExecution(); - if (squeezeDims != null && numIArguments() < squeezeDims.length) { - addIArgument(squeezeDims); - } - } - @Override public String opName() { return "squeeze"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 2de0a29c5..b9687e598 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -65,11 +65,6 @@ public class Transpose extends DynamicCustomOp { public Transpose() { } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public Map> mappingsForFunction() { Map> ret = new LinkedHashMap<>(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/HistogramFixedWidth.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/HistogramFixedWidth.java index 8894a87ec..091846db8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/HistogramFixedWidth.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/HistogramFixedWidth.java @@ -70,16 +70,6 @@ public class HistogramFixedWidth extends DynamicCustomOp { //No op - just need the inputs } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments.isEmpty()){ - //Num bins is 3rd array - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - - @Override public List doDiff(List f1) { throw new UnsupportedOperationException("Not supported"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index f940ac7a0..72d2823b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -82,17 +82,6 @@ public class Pad extends DynamicCustomOp { //Constant value is resolved just before execution } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3){ - INDArray arr = arg(2).getArr(); - this.tArguments.clear(); - this.tArguments.add(arr.getDouble(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - - @Override public List doDiff(List i_v) { //Pad backprop: it's basically slice op... diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index 73d4d281e..6d6798701 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -54,14 +54,6 @@ public class UnsortedSegmentMax extends DynamicCustomOp { return "UnsortedSegmentMax"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentMaxBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index a110d7225..f51b94218 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -54,14 +54,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp { return "UnsortedSegmentMean"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentMeanBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 6666981ff..1b885676e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -54,14 +54,6 @@ public class UnsortedSegmentMin extends DynamicCustomOp { return "UnsortedSegmentMin"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentMinBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index 36ba70342..b2e254fb7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -54,14 +54,6 @@ public class UnsortedSegmentProd extends DynamicCustomOp { return "UnsortedSegmentProd"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentProdBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index cfb62450a..ef34e9f81 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -53,14 +53,6 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { return "UnsortedSegmentSqrtN"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentSqrtNBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java index 3e4c4d2f1..466cc8cf2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java @@ -55,14 +55,6 @@ public class UnsortedSegmentSum extends DynamicCustomOp { return "UnsortedSegmentSum"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentSumBp(arg(0), arg(1), gradients.get(0), numSegments));