From e18e2dc0140db579e322fd70ea5b916d907d4d9b Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 1 Aug 2019 19:30:58 +1000 Subject: [PATCH] Various ND4J/DL4J fixes (#90) * Deprecate Old*Op instances Signed-off-by: AlexDBlack * #8063 #8054 Broadcast exceptions + cleanup inplace ops Signed-off-by: AlexDBlack * Small fix Signed-off-by: AlexDBlack * Remove bad test condition Signed-off-by: AlexDBlack * #7993 Fix shape function issue in crop_and_resize op Signed-off-by: AlexDBlack * DL4J SameDiff lambda layer fix Signed-off-by: AlexDBlack * #8029 Fix for pnorm backprop math Signed-off-by: AlexDBlack --- .../GradientCheckTestsMasking.java | 12 +- .../nn/layers/mkldnn/MKLDNNConvHelper.java | 2 +- .../nn/layers/samediff/SameDiffLayer.java | 9 +- .../generic/parity_ops/crop_and_resize.cpp | 2 +- .../declarable/helpers/cpu/convolutions.cpp | 2 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 21 +-- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 146 ++---------------- .../transforms/pairwise/arithmetic/AddOp.java | 4 + .../transforms/pairwise/arithmetic/DivOp.java | 4 + .../transforms/pairwise/arithmetic/MulOp.java | 4 + .../pairwise/arithmetic/OldAddOp.java | 5 +- .../pairwise/arithmetic/OldDivOp.java | 5 +- .../pairwise/arithmetic/OldMulOp.java | 5 +- .../pairwise/arithmetic/OldRDivOp.java | 5 +- .../pairwise/arithmetic/OldRSubOp.java | 5 +- .../pairwise/arithmetic/OldSubOp.java | 5 +- .../pairwise/arithmetic/RDivOp.java | 4 + .../pairwise/arithmetic/RSubOp.java | 4 + .../transforms/pairwise/arithmetic/SubOp.java | 4 + .../java/org/nd4j/linalg/api/shape/Shape.java | 35 +++++ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 9 +- .../linalg/broadcast/BasicBroadcastTests.java | 76 ++++++++- .../CropAndResizeDataSetPreProcessorTest.java | 8 +- 23 files changed, 188 insertions(+), 188 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java index 00bc3e196..c1e97a385 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java @@ -290,18 +290,16 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { int nOut = 2; //1 example, TS length 3 - INDArray mask1 = Nd4j.create(new double[] {1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0}, new int[] {1, nOut, 3}, 'f'); + INDArray mask1 = Nd4j.create(new double[] {1, 0, 0, 1, 0, 1}, new int[] {1, nOut, 3}, 'f'); //1 example, TS length 1 - INDArray mask2 = Nd4j.create(new double[] {1, 1, 0, 1}, new int[] {1, nOut, 1}, 'f'); + INDArray mask2 = Nd4j.create(new double[] {1, 1}, new int[] {1, nOut, 1}, 'f'); //3 examples, TS length 3 INDArray mask3 = Nd4j.create(new double[] { //With fortran order: dimension 0 (example) changes quickest, followed by dimension 1 (value within time // step) followed by time index (least frequently) - 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, - - 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, - - 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0}, new int[] {3, nOut, 3}, 'f'); + 1, 0, 1, 0, 1, 1, + 0, 1, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 1,}, new int[] {3, nOut, 3}, 'f'); INDArray[] labelMasks = new INDArray[] {mask1, mask2, mask3}; ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java index 2884f4ced..244f7c1fc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java @@ 
-127,7 +127,7 @@ public class MKLDNNConvHelper implements ConvolutionHelper { outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation } - if(context == null || true){ + if(context == null ){ context = Nd4j.getExecutioner().buildContext(); context.setIArguments(kernel[0], kernel[1], strides[0], strides[1], diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java index 912bc45a8..5df7afd72 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java @@ -159,7 +159,14 @@ public class SameDiffLayer extends AbstractLayer { g.gradientForVariable().put(s, dl4jGrad); } - dLdIn = sameDiff.grad(INPUT_KEY).getArr(); + SDVariable v = sameDiff.grad(INPUT_KEY); + dLdIn = v.getArr(); + + if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){ + //Edge case with lambda layers like identity: SameDiff doesn't store the placeholders + // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here + dLdIn = epsilon; + } } //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp index 09c7b9579..f1fb9b2c5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp @@ -56,7 +56,7 @@ namespace nd4j { } DECLARE_SHAPE_FN(crop_and_resize) { - auto in = inputShape->at(1); + auto in = inputShape->at(0); Nd4jLong outputShape[4]; diff --git 
a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index 14d6a95ca..22e7d4d2b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -2014,7 +2014,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - pgI[kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kh + kw]), extraParam0 - 1.f); + pgI[kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kh + kw]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn(pIn[kh + kw]); } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index fd46da910..2e6ed9ca1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -16,22 +16,15 @@ package org.nd4j.autodiff.samediff; -import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput; -import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; - import com.google.common.collect.HashBasedTable; import com.google.common.collect.Table; import com.google.common.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; import com.rits.cloning.Cloner; import com.rits.cloning.IFastCloner; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import lombok.*; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; -import org.apache.commons.io.output.CloseShieldOutputStream; import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import 
org.nd4j.autodiff.execution.conf.OutputMode; @@ -49,7 +42,6 @@ import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; import org.nd4j.graph.*; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -59,7 +51,6 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.controlflow.If; import org.nd4j.linalg.api.ops.impl.controlflow.While; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; @@ -76,7 +67,6 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.exception.ND4JException; import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables; @@ -89,11 +79,11 @@ import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.DeviceLocalNDArray; import org.nd4j.linalg.util.ND4JFileUtils; -import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.weightinit.WeightInitScheme; import org.nd4j.weightinit.impl.ConstantInitScheme; import org.nd4j.weightinit.impl.NDArraySupplierInitScheme; import org.nd4j.weightinit.impl.ZeroInitScheme; +import org.tensorflow.framework.GraphDef; import java.io.*; import java.lang.reflect.Method; @@ -101,10 
+91,11 @@ import java.nio.ByteBuffer; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; -import java.util.zip.ZipEntry; -import java.util.zip.ZipFile; -import java.util.zip.ZipOutputStream; -import org.tensorflow.framework.GraphDef; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput; +import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; /** * SameDiff is the entrypoint for ND4J's automatic differentiation functionality. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 3e85772cb..c18f70d4b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -3692,7 +3692,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray divi(INDArray other) { - validateNumericalArray("divi", false); return divi(other, this); } @@ -3706,30 +3705,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray divi(INDArray other, INDArray result) { validateNumericalArray("divi", false); - if (other.isScalar()) { - return divi(other.getDouble(0), result); - } - - if (isScalar()) { - return other.rdivi(getDouble(0), result); - } - - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - val outShape = Shape.broadcastOutputShape(this.shape(), other.shape()); - Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); - - Nd4j.exec(new DivOp(new INDArray[]{this, other}, new INDArray[]{result})); - - return 
result; - } else if(!Shape.shapeEquals(this.shape(),other.shape())) { - int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape()); - Nd4j.getExecutioner().exec(new BroadcastDivOp(this,other,result,broadcastDimensions)); - return result; - } - - - LinAlgExceptions.assertSameShape(other, result); - Nd4j.getExecutioner().exec(new OldDivOp(this, other, result)); + Shape.assertBroadcastable("divi", this, other, result); + Nd4j.exec(new DivOp(this, other, result)); return result; } @@ -3741,7 +3718,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray muli(INDArray other) { - validateNumericalArray("muli", false); return muli(other, this); } @@ -3755,29 +3731,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray muli(INDArray other, INDArray result) { validateNumericalArray("muli", false); - if (other.isScalar()) { - return muli(other.getDouble(0), result); - } - if (isScalar()) { - return other.muli(getDouble(0), result); - } - - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - val outShape = Shape.broadcastOutputShape(this.shape(), other.shape()); - Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); - - Nd4j.exec(new MulOp(new INDArray[]{this, other}, new INDArray[]{result})); - - return result; - } else if(!Shape.shapeEquals(this.shape(),other.shape())) { - int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape()); - Nd4j.getExecutioner().exec(new BroadcastMulOp(this,other,result,broadcastDimensions)); - return result; - } - - LinAlgExceptions.assertSameShape(other, result); - - Nd4j.getExecutioner().exec(new OldMulOp(this, other, result)); + Shape.assertBroadcastable("muli", this, other, result); + Nd4j.exec(new MulOp(this, other, result)); return result; } @@ -3802,31 +3757,8 @@ public abstract class 
BaseNDArray implements INDArray, Iterable { @Override public INDArray subi(INDArray other, INDArray result) { validateNumericalArray("subi", false); - if (other.isScalar()) { - return subi(other.getDouble(0), result); - } - if (isScalar()) { - return other.rsubi(getDouble(0), result); - } - - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - val outShape = Shape.broadcastOutputShape(this.shape(), other.shape()); - Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); - - Nd4j.exec(new SubOp(new INDArray[]{this, other}, new INDArray[]{result})); - - return result; - } else if(!Shape.shapeEquals(this.shape(),other.shape())) { - int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape()); - Nd4j.getExecutioner().exec(new BroadcastSubOp(this,other,result,broadcastDimensions)); - return result; - } - - - LinAlgExceptions.assertSameShape(other, result); - - - Nd4j.getExecutioner().exec(new OldSubOp(this, other,result)); + Shape.assertBroadcastable("subi", this, other, result); + Nd4j.exec(new SubOp(this, other, result)); return result; } @@ -3851,33 +3783,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray addi(INDArray other, INDArray result) { validateNumericalArray("addi", false); - if (other.isScalar()) { - return this.addi(other.getDouble(0), result); - } - - if (isScalar()) { - return other.addi(getDouble(0), result); - } - - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - val outShape = Shape.broadcastOutputShape(this.shape(), other.shape()); - Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); - - Nd4j.exec(new AddOp(new INDArray[]{this, other}, new INDArray[]{result})); - - return result; - } else if(!Shape.shapeEquals(this.shape(),other.shape())) { - int[] 
broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape()); - result = Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(),other.shape())); - Nd4j.getExecutioner().exec(new BroadcastAddOp(this,other,result,broadcastDimensions)); - return result; - } else { - - LinAlgExceptions.assertSameShape(this, other, result); - - Nd4j.getExecutioner().exec(new OldAddOp(this, other, result)); - return result; - } + Shape.assertBroadcastable("addi", this, other, result); + Nd4j.exec(new AddOp(this, other, result)); + return result; } /** @@ -3954,7 +3862,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray rdivi(INDArray other, INDArray result) { validateNumericalArray("rdivi", false); - return other.divi(this, result); + Shape.assertBroadcastable("rdivi", this, other, result); + Nd4j.exec(new RDivOp(this, other, result)); + return result; } /** @@ -4003,33 +3913,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray rsubi(INDArray other, INDArray result) { validateNumericalArray("rsubi", false); - if (other.isScalar()) { - return this.rsubi(other.getDouble(0), result); - } - - if (isScalar()) { - return other.rsubi(getDouble(0), result); - } - - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - val outShape = Shape.broadcastOutputShape(this.shape(), other.shape()); - Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); - - Nd4j.exec(new RSubOp(new INDArray[]{this, other}, new INDArray[]{result})); - - return result; - } else if(!Shape.shapeEquals(this.shape(),other.shape())) { - int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape()); - result = Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(),other.shape())); - Nd4j.getExecutioner().exec(new 
BroadcastRSubOp(this,other,result,broadcastDimensions)); - return result; - } else { - - LinAlgExceptions.assertSameShape(this, other, result); - - Nd4j.getExecutioner().exec(new OldRSubOp(this, other, result)); - return result; - } + Shape.assertBroadcastable("rsubi", this, other, result); + Nd4j.exec(new RSubOp(this, other, result)); + return result; } /** @@ -6796,6 +6682,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { throw new IllegalStateException("Cannot perform operation " + opName + " on empty array with datatype " + dataType()); } + + @Override public boolean closeable() { if (released || isAttached()) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java index 2d6504f14..672159a3e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java @@ -38,6 +38,10 @@ public class AddOp extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public AddOp(INDArray first, INDArray second, INDArray result){ + this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); + } + public AddOp(INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java index 365613a68..b76942e95 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java @@ -37,6 +37,10 @@ public class DivOp extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public DivOp(INDArray first, INDArray second, INDArray result){ + this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); + } + public DivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java index f938c7675..4636f9bc8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java @@ -37,6 +37,10 @@ public class MulOp extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public MulOp(INDArray first, INDArray second, INDArray result){ + this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); + } + public MulOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAddOp.java index 60a876972..18b4e5912 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAddOp.java @@ -28,10 +28,9 @@ import java.util.ArrayList; import java.util.List; /** - * Add operation for two operands - * - * @author Adam Gibson + * @deprecated Use {@link AddOp} */ +@Deprecated public class OldAddOp extends BaseTransformAnyOp { public OldAddOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { super(sameDiff, i_v1, i_v2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldDivOp.java index f9ba500bb..16f2a7761 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldDivOp.java @@ -28,10 +28,9 @@ import java.util.ArrayList; import java.util.List; /** - * Division operation - * - * @author Adam Gibson + * @deprecated Use {@link DivOp} */ +@Deprecated public class OldDivOp extends BaseTransformAnyOp { public OldDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { super(sameDiff, i_v1, i_v2); diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldMulOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldMulOp.java index 12a560867..d1e8b6bc6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldMulOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldMulOp.java @@ -28,10 +28,9 @@ import java.util.ArrayList; import java.util.List; /** - * Multiplication operation - * - * @author Adam Gibson + * @deprecated Use {@link MulOp} */ +@Deprecated public class OldMulOp extends BaseTransformAnyOp { public OldMulOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { super(sameDiff, i_v1, i_v2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRDivOp.java index cbe9a8cc1..9838cfd37 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRDivOp.java @@ -28,10 +28,9 @@ import java.util.ArrayList; import java.util.List; /** - * OldReverse Division operation - * - * @author Adam Gibson + * @deprecated Use {@link RDivOp} */ +@Deprecated public class OldRDivOp extends BaseTransformAnyOp { public OldRDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { super(sameDiff, i_v1, i_v2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRSubOp.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRSubOp.java index 444a08295..91d2cd90b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRSubOp.java @@ -26,10 +26,9 @@ import java.util.ArrayList; import java.util.List; /** - * Division operation - * - * @author Adam Gibson + * @deprecated Use {@link RSubOp} */ +@Deprecated public class OldRSubOp extends BaseTransformAnyOp { public OldRSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { super(sameDiff, i_v1, i_v2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldSubOp.java index 275ec3270..1f806c34a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldSubOp.java @@ -28,10 +28,9 @@ import java.util.ArrayList; import java.util.List; /** - * Division operation - * - * @author Adam Gibson + * @deprecated Use {@link SubOp} */ +@Deprecated public class OldSubOp extends BaseTransformAnyOp { public OldSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { super(sameDiff, i_v1, i_v2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java index 12bfab7c4..d54d91dbc 100644 
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java @@ -38,6 +38,10 @@ public class RDivOp extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public RDivOp(INDArray first, INDArray second, INDArray result){ + this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); + } + public RDivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java index 7a0588de3..5b6833fd8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java @@ -45,6 +45,10 @@ public class RSubOp extends BaseDynamicTransformOp { this(sameDiff, new SDVariable[]{i_v1, i_v2}, inPlace); } + public RSubOp(INDArray first, INDArray second, INDArray result){ + this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); + } + public RSubOp() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java index a05fb2791..0d222329e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java @@ -37,6 +37,10 @@ public class SubOp extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public SubOp(INDArray first, INDArray second, INDArray result){ + this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); + } + public SubOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index 76f50d733..442dd0f5f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -251,6 +251,41 @@ public class Shape { return false; } + /** + * Assert that the broadcast operation {@code result = first.op(second)} is valid, given the + * shapes of first, second, and result.
+ * Throws an exception otherwise + * + * @param op Name of the operation + * @param first First array + * @param second Second array + * @param result Result array. + */ + public static void assertBroadcastable(String op, INDArray first, INDArray second, INDArray result){ + long[] fShape = first.shape(); + long[] sShape = second.shape(); + Preconditions.checkState(Shape.areShapesBroadcastable(fShape, sShape), + "Cannot perform operation \"%s\" - shapes are not equal and are not broadcastable." + + "first.shape=%s, second.shape=%s", op, fShape, sShape); + + long[] outShape = Shape.broadcastOutputShape(fShape, sShape); + if (!Arrays.equals(outShape, result.shape())) { + //Two cases + // 1. x.addi(y) + // 2. x.addi(y, z) + + String extra = ""; + if(first == result){ + extra = ".\nIn-place operations like x." + op + "(y) can only be performed when x and y have the same shape," + + " or x and y are broadcastable with x.shape() == broadcastShape(x,y)"; + } + + throw new IllegalStateException("Cannot perform in-place operation \"" + op + "\": result array shape does" + + " not match the broadcast operation output shape: " + Arrays.toString(fShape) + "." 
+ op + "(" + + Arrays.toString(sShape) + ") != " + Arrays.toString(result.shape()) + extra); + } + } + public static long[] broadcastOutputShape(long[] left,long[] right) { if (containsZeros(left)) return left; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 15ad2f517..d37ddb889 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -80,6 +80,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse; import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -6222,7 +6223,7 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(expected, b.rdiv(2)); assertEquals(expected2, d.rdivColumnVector(c)); - assertEquals(expected, b.rdiv(Nd4j.scalar(2))); + assertEquals(expected, b.rdiv(Nd4j.scalar(2.0))); assertEquals(expected, b.rdivColumnVector(Nd4j.scalar(2))); } @@ -7958,7 +7959,11 @@ public class Nd4jTestsC extends BaseNd4jTest { c.addOutputArgument(out); Nd4j.getExecutioner().exec(c); - assertEquals(Nd4j.createFromArray(1f, 3f, 4f), out); + List l = c.calculateOutputShape(); + + System.out.println(Arrays.toString(l.get(0).getShape())); + + //from [4,4,3] to [2,4,6] then crop to [2,4,5] } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java index 8d99cbd61..d9057c95a 100644 --- 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java @@ -30,6 +30,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** * @author raver119@gmail.com @@ -122,42 +123,42 @@ public class BasicBroadcastTests extends BaseNd4jTest { assertEquals(e, z); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = IllegalStateException.class) public void basicBroadcastFailureTest_1() { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val z = x.subi(y); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = IllegalStateException.class) public void basicBroadcastFailureTest_2() { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val z = x.divi(y); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = IllegalStateException.class) public void basicBroadcastFailureTest_3() { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val z = x.muli(y); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = IllegalStateException.class) public void basicBroadcastFailureTest_4() { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val z = x.addi(y); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = IllegalStateException.class) public void basicBroadcastFailureTest_5() { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val z 
= x.rsubi(y); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = IllegalStateException.class) public void basicBroadcastFailureTest_6() { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); @@ -206,7 +207,7 @@ public class BasicBroadcastTests extends BaseNd4jTest { assertEquals(y, z); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = IllegalStateException.class) public void emptyBroadcastTest_2() { val x = Nd4j.create(DataType.FLOAT, 1, 2); val y = Nd4j.create(DataType.FLOAT, 0, 2); @@ -226,6 +227,67 @@ public class BasicBroadcastTests extends BaseNd4jTest { assertEquals(y, z); } + + @Test + public void testValidInvalidBroadcast(){ + INDArray x = Nd4j.rand(3,1); + INDArray y = Nd4j.create(3, 4); + + x.add(y); + y.addi(x); + try { + x.addi(y); + } catch (Exception e){ + String s = e.getMessage(); + assertTrue(s, s.contains("broadcast") && s.contains("shape")); + } + + x.sub(y); + y.subi(x); + try { + x.subi(y); + } catch (Exception e){ + String s = e.getMessage(); + assertTrue(s, s.contains("broadcast") && s.contains("shape")); + } + + x.mul(y); + y.muli(x); + try { + x.muli(y); + } catch (Exception e){ + String s = e.getMessage(); + assertTrue(s, s.contains("broadcast") && s.contains("shape")); + } + + x.div(y); + y.divi(x); + try { + x.divi(y); + } catch (Exception e){ + String s = e.getMessage(); + assertTrue(s, s.contains("broadcast") && s.contains("shape")); + } + + x.rsub(y); + y.rsubi(x); + try { + x.rsubi(y); + } catch (Exception e){ + String s = e.getMessage(); + assertTrue(s, s.contains("broadcast") && s.contains("shape")); + } + + x.rdiv(y); + y.rdivi(x); + try { + x.rdivi(y); + } catch (Exception e){ + String s = e.getMessage(); + assertTrue(s, s.contains("broadcast") && s.contains("shape")); + } + } + @Override public char ordering() { return 'c'; diff --git 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java index 63abfffcd..904484d5f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java @@ -7,8 +7,7 @@ import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public class CropAndResizeDataSetPreProcessorTest { @@ -93,10 +92,7 @@ public class CropAndResizeDataSetPreProcessorTest { // Assert INDArray results = ds.getFeatures(); long[] shape = results.shape(); - assertEquals(1, shape[0]); - assertEquals(4, shape[1]); - assertEquals(3, shape[2]); - assertEquals(3, shape[3]); + assertArrayEquals(new long[]{1, 4, 3, 3}, shape); // Test a few values assertEquals(55.0, results.getDouble(0, 0, 0, 0), 0.0);