From b95417f7c58fca7ee29681925e1aa8815f811ed9 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 30 Jul 2019 00:27:38 +1000 Subject: [PATCH] Various ND4J/DL4J fixes and improvements (#87) * Reshape and reallocate - small fixes Signed-off-by: AlexDBlack * Reshape and reallocate - small fixes Signed-off-by: AlexDBlack * #6488 ElementWiseVertex broadcast support Signed-off-by: AlexDBlack * Constructors and broadcast supported it Transforms.max/min Signed-off-by: AlexDBlack * #8054 ElementWiseVertex now supports broadcast inputs Signed-off-by: AlexDBlack * #8057 Nd4j.create overload dtype fix Signed-off-by: AlexDBlack * #7551 ND4J Shape validation fix Signed-off-by: AlexDBlack --- .../GradientCheckTestsComputationGraph.java | 55 ++++- .../nn/conf/graph/ElementWiseVertexTest.java | 5 + .../conf/ComputationGraphConfiguration.java | 2 +- .../graph/vertex/impl/ElementWiseVertex.java | 197 +++++++++++++++--- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 3 +- .../api/ops/impl/transforms/custom/Max.java | 4 + .../api/ops/impl/transforms/custom/Min.java | 4 + .../api/ops/impl/transforms/same/Max.java | 2 +- .../api/ops/impl/transforms/same/Min.java | 2 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 19 +- .../linalg/ops/transforms/Transforms.java | 32 ++- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 66 +++++- .../linalg/api/buffer/BaseDataBuffer.java | 3 +- 13 files changed, 346 insertions(+), 48 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java index 1eb893e3f..623158c68 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java @@ -16,6 +16,7 @@ package org.deeplearning4j.gradientcheck; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; @@ -53,9 +54,9 @@ import java.util.Arrays; import java.util.Map; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; +@Slf4j public class GradientCheckTestsComputationGraph extends BaseDL4JTest { public static final boolean PRINT_RESULTS = true; @@ -287,6 +288,56 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { } } + @Test + public void testElementWiseVertexBroadcast(){ + + ElementWiseVertex.Op[] ops = + new ElementWiseVertex.Op[] {ElementWiseVertex.Op.Add, ElementWiseVertex.Op.Average, + ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Max, ElementWiseVertex.Op.Product}; + + for(boolean firstSmaller : new boolean[]{false, true}) { + for (ElementWiseVertex.Op op : ops) { + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .updater(new NoOp()) + .dataType(DataType.DOUBLE) + .activation(Activation.TANH) + .seed(12345) + .graphBuilder() + .addInputs("in") + .setOutputs("out") + .layer("l1", new DenseLayer.Builder().nIn(3).nOut(firstSmaller ? 1 : 3).build(), "in") //[mb,3] + .layer("l2", new DenseLayer.Builder().nIn(3).nOut(firstSmaller ? 3 : 1).build(), "in") //[mb,1] + .addVertex("ew", new ElementWiseVertex(op), "l1", "l2") + .layer("out", new OutputLayer.Builder().nIn(3).nOut(2).lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).build(), "ew") + .build(); + + ComputationGraph graph = new ComputationGraph(conf); + graph.init(); + + for (int mb : new int[]{1, 5}) { + String msg = (firstSmaller ? "first smaller, " : "second smaller, ") + "mb=" + mb + ", op=" + op; + + log.info("Test: {}", msg); + + INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3); + + INDArray out = graph.outputSingle(in); + assertArrayEquals(new long[]{mb, 2}, out.shape()); + + INDArray labels = TestUtils.randomOneHot(mb, 2); + + graph.fit(new DataSet(in, labels)); + + boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in}, + new INDArray[]{labels}); + assertTrue(msg, gradOK); + TestUtils.testModelSerialization(graph); + } + } + } + } + @Test public void testCnnDepthMerge() { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java index afae4b1dc..5bbb8846d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.graph; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -34,6 +35,7 @@ import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution; +import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; @@ -42,6 +44,8 @@ import org.nd4j.linalg.primitives.Pair; import java.util.Map; +import static org.junit.Assert.assertArrayEquals; + /** * Created by binesh on 6/14/2017. */ @@ -690,6 +694,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon); } + private static double mse(INDArray output, INDArray target) { double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue() / (output.columns() * output.rows()); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java index a259c7b2b..6cd8f06b3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java @@ -350,7 +350,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { "Use .addInputs(String...) to label (and give an ordering to) the network inputs"); } if ((networkOutputs == null || networkOutputs.isEmpty()) && !allowNoOutput) { - throw new IllegalStateException("Invalid configuration: network has no outputs." + + throw new IllegalStateException("Invalid configuration: network has no outputs. " + "Use .setOutput(String...) to specify (and give an ordering to) the output vertices, " + "or use allowNoOutputs(true) to disable this check"); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java index d9fe6acee..b22aae451 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java @@ -27,15 +27,20 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo; import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp; +import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Conditions; +import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import java.util.Arrays; + /** An ElementWiseVertex is used to combine the activations of two or more layer in an element-wise manner
* For example, the activations may be combined by addition, subtraction or multiplication or by selecting the maximum. * Addition, Average, Product and Max may use an arbitrary number of input arrays. Note that in the case of subtraction, only two inputs may be used. @@ -80,17 +85,44 @@ public class ElementWiseVertex extends BaseGraphVertex { if (inputs.length == 1) return workspaceMgr.dup(ArrayType.ACTIVATIONS, inputs[0]); + boolean isBc = false; + for(int i=1; i(null, new INDArray[] {workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon)}); + boolean broadcastCase = false; + for( int i=1; i input 0 backprops epsilon, input 1 backprops epsilon.sum(1,keepDim=true) + if(inputs[i].equalShapes(epsilon)){ + out[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon); + } else { + int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape()); + try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){ + out[i] = epsilon.sum(true, bcDim); + } + } + } + } return new Pair<>(null, out); case Average: INDArray[] outAverage = new INDArray[nInForwardPass]; try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){ - for (int i = 0; i < nInForwardPass; i++) - outAverage[i] = epsilon.div(nInForwardPass); + for (int i = 0; i < nInForwardPass; i++) { + if(inputs[i].equalShapes(epsilon)){ + outAverage[i] = epsilon.div(nInForwardPass); + } else { + int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape()); + outAverage[i] = epsilon.div(nInForwardPass).sum(true, bcDim); + } + } } return new Pair<>(null, outAverage); case Subtract: INDArray[] out2 = new INDArray[2]; - out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon); - out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi(); + if(!broadcastCase){ + out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon); + out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi(); + } else { + if(inputs[0].equalShapes(epsilon)){ + //Second input is smaller/broadcast + out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon); + int[] bcDim = Shape.getBroadcastDimensions(inputs[1].shape(), epsilon.shape()); + try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) { + out2[1] = epsilon.sum(true, bcDim).negi(); + } + } else { + //First input is smaller/broadcast + int[] bcDim = Shape.getBroadcastDimensions(inputs[0].shape(), epsilon.shape()); + try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) { + out2[0] = epsilon.sum(true, bcDim); + } + out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi(); + } + } return new Pair<>(null, out2); case Product: INDArray[] out_product = new INDArray[nInForwardPass]; + INDArray[] inBc = inputs; + if(broadcastCase){ + inBc = new INDArray[inputs.length]; + for( int i=0; i(null, out_product); case Max: INDArray[] outMax = new INDArray[nInForwardPass]; INDArray maxIndices = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, DataType.INT, epsilon.shape(), epsilon.ordering()); + + INDArray[] bcIn = inputs; + if(broadcastCase){ + //Broadcast to right shape... + bcIn = new INDArray[inputs.length]; + for( int i=0; i(null, outMax); default: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index e77dd7fd0..3e85772cb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -2476,6 +2476,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { // length/data.length can be different in case of Threshold conversion if(isEmpty() || isS()) return false; + return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0 || (length() < data().length() && data.dataType() != DataType.INT) || data().originalDataBuffer() != null; @@ -4577,7 +4578,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { return ret; } else { INDArray ret = this.dup(order); - return ret.reshape(order, shape); + return Nd4j.create(ret.data(), shape); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java index d0de5fb21..6c877f96d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java @@ -44,6 +44,10 @@ public class Max extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public Max( INDArray first, INDArray second, INDArray out){ + super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out}); + } + public Max( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java index 97beae406..73bfbacc7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java @@ -44,6 +44,10 @@ public class Min extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public Min( INDArray first, INDArray second, INDArray out){ + super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out}); + } + public Min( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java index 2816d4e60..db682174c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java @@ -26,7 +26,7 @@ import java.util.Collections; import java.util.List; /** - * Calculate the absolute minimum over a vector + * Calculate the maximum value between two arrays in an elementwise fashion, broadcasting if required * * @author raver119@gmail.com */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java index eb24acdef..6585ace19 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java @@ -26,7 +26,7 @@ import java.util.Collections; import java.util.List; /** - * Calculate the absolute minimum over a vector + * Calculate the minimum value between two arrays in an elementwise fashion, broadcasting if required * * @author raver119@gmail.com */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index b82a9e19e..a63d6c43a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -538,7 +538,7 @@ public class Nd4j { public static INDArray create(int[] sliceShape, float[]... arrays) { //TODO: Remove duplicate code. int slices = arrays.length; - INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape)); + INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape))); for (int i = 0; i < ret.slices(); i++) ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape))); return ret; @@ -572,7 +572,7 @@ public class Nd4j { */ public static INDArray create(int[] sliceShape, double[]... arrays) { int slices = arrays.length; - INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape)); + INDArray ret = Nd4j.createUninitialized(DataType.DOUBLE, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape))); for (int i = 0; i < ret.slices(); i++) ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape))); return ret; @@ -3984,6 +3984,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(int[] data, long[] shape, DataType type) { + checkShapeValues(data.length, shape); return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace()); } @@ -3991,6 +3992,7 @@ public class Nd4j { * See {@link #create(int[], long[], DataType)} */ public static INDArray create(long[] data, long[] shape, DataType type) { + checkShapeValues(data.length, shape); return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace()); } @@ -3998,6 +4000,7 @@ public class Nd4j { * See {@link #create(int[], long[], DataType)} */ public static INDArray create(double[] data, long[] shape, DataType type) { + checkShapeValues(data.length, shape); return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace()); } @@ -4005,6 +4008,7 @@ public class Nd4j { * See {@link #create(int[], long[], DataType)} */ public static INDArray create(float[] data, long[] shape, DataType type) { + checkShapeValues(data.length, shape); return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace()); } @@ -4012,6 +4016,7 @@ public class Nd4j { * See {@link #create(int[], long[], DataType)} */ public static INDArray create(short[] data, long[] shape, DataType type) { + checkShapeValues(data.length, shape); return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace()); } @@ -4019,6 +4024,7 @@ public class Nd4j { * See {@link #create(int[], long[], DataType)} */ public static INDArray create(byte[] data, long[] shape, DataType type) { + checkShapeValues(data.length, shape); return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace()); } @@ -4026,6 +4032,7 @@ public class Nd4j { * See {@link #create(int[], long[], DataType)} */ public static INDArray create(boolean[] data, long[] shape, DataType type) { + checkShapeValues(data.length, shape); return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace()); } @@ -5165,17 +5172,17 @@ public class Nd4j { protected static void checkShapeValues(int length, int... shape) { checkShapeValues(shape); - if (ArrayUtil.prodLong(shape) > length) + if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0)) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + length); + + " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided"); } protected static void checkShapeValues(int length, long... shape) { checkShapeValues(shape); - if (ArrayUtil.prodLong(shape) > length) + if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0)) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + length); + + " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided"); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java index e676197ee..57660b8d7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java @@ -45,9 +45,11 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; import org.nd4j.linalg.api.ops.impl.transforms.same.*; import org.nd4j.linalg.api.ops.impl.transforms.strict.*; import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.inverse.InvertMatrix; +import java.util.Arrays; import java.util.List; /** @@ -858,11 +860,11 @@ public class Transforms { * @return */ public static INDArray max(INDArray first, INDArray second, boolean dup) { - INDArray result = first; - if (dup) { - result = first.ulike(); - } - return exec(new OldMax(first, second, result)); + long[] outShape = broadcastResultShape(first, second); //Also validates + Preconditions.checkState(dup || Arrays.equals(outShape, first.shape()), "Cannot do inplace max operation when first input is not equal to result shape (%ndShape vs. result %s)", + first, outShape); + INDArray out = dup ? Nd4j.create(first.dataType(), outShape) : first; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second, out))[0]; } /** @@ -908,10 +910,11 @@ public class Transforms { * @return */ public static INDArray min(INDArray first, INDArray second, boolean dup) { - if (dup) { - first = first.dup(); - } - return exec(new OldMin(second, first, first)); + long[] outShape = broadcastResultShape(first, second); //Also validates + Preconditions.checkState(dup || Arrays.equals(outShape, first.shape()), "Cannot do inplace min operation when first input is not equal to result shape (%ndShape vs. result %s)", + first, outShape); + INDArray out = dup ? Nd4j.create(first.dataType(), outShape) : first; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second, out))[0]; } /** @@ -1179,4 +1182,15 @@ public class Transforms { } } + + protected static long[] broadcastResultShape(INDArray first, INDArray second){ + if(first.equalShapes(second)){ + return first.shape(); + } else if(Shape.areShapesBroadcastable(first.shape(), second.shape())){ + return Shape.broadcastOutputShape(first.shape(), second.shape()); + } else { + throw new IllegalStateException("Array shapes are not broadcastable: " + Arrays.toString(first.shape()) + + " vs. " + Arrays.toString(second.shape())); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index ab2685be2..246b0a8ef 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -2699,6 +2699,21 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(expected, actual); } + @Test + public void testBroadcastDiv2(){ + INDArray arr = Nd4j.ones(DataType.DOUBLE, 1, 64, 125, 125).muli(2); + INDArray vec = Nd4j.ones(DataType.DOUBLE, 64).muli(2); + + INDArray exp = Nd4j.ones(DataType.DOUBLE, 1, 64, 125, 125); + INDArray out = arr.like(); + + for( int i=0; i<10; i++ ) { + out.assign(0.0); + Nd4j.getExecutioner().exec(new BroadcastDivOp(arr, vec, out, 1)); + assertEquals(exp, out); + } + } + @Test public void testBroadcastMult() { @@ -7417,7 +7432,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr1a = Nd4j.create(new long[]{2,3}, 'c').get(NDArrayIndex.all(), NDArrayIndex.interval(0,2)); INDArray arr3 = arr1a.reshape('c', false, 4,1); - assertFalse(arr3.isView()); //Should be copy + boolean isView = arr3.isView(); + assertFalse(isView); //Should be copy try{ INDArray arr4 = arr1a.reshape('c', true, 4,1); @@ -7861,6 +7877,54 @@ public class Nd4jTestsC extends BaseNd4jTest { final INDArray arr2 = arr1.reshape(3,1); assertEquals("Incorrect type!", DataType.FLOAT, arr1.mmul(arr2).dataType()); } + + + @Test + public void testCreateDtypes() { + int[] sliceShape = new int[] {9}; + float[] arrays = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; + double [] arrays_double = new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + + INDArray x = Nd4j.create( sliceShape, arrays, arrays ); + assertEquals(DataType.FLOAT, x.dataType()); + + INDArray xd = Nd4j.create( sliceShape, arrays_double, arrays_double ); + assertEquals(DataType.DOUBLE, xd.dataType()); + } + + + @Test + public void testCreateShapeValidation(){ + try { + Nd4j.create(new double[]{1, 2, 3}, new int[]{1, 1}); + fail(); + } catch (Exception t){ + assertTrue(t.getMessage().contains("length")); + } + + try { + Nd4j.create(new float[]{1, 2, 3}, new int[]{1, 1}); + fail(); + } catch (Exception t){ + assertTrue(t.getMessage().contains("length")); + } + + try { + Nd4j.create(new byte[]{1, 2, 3}, new long[]{1, 1}, DataType.BYTE); + fail(); + } catch (Exception t){ + assertTrue(t.getMessage().contains("length")); + } + + try { + Nd4j.create(new double[]{1, 2, 3}, new int[]{1, 1}, 'c'); + fail(); + } catch (Exception t){ + assertTrue(t.getMessage().contains("length")); + } + } + + /////////////////////////////////////////////////////// protected static void fillJvmArray3D(float[][][] arr) { int cnt = 1; diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index ee83e82e5..82fff8437 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -2601,7 +2601,8 @@ public abstract class BaseDataBuffer implements DataBuffer { } Pointer.memcpy(pointer, oldPointer, this.length() * getElementSize()); - //this.underlyingLength = length; + this.underlyingLength = length; + this.length = length; return this; }