Various ND4J/DL4J fixes and improvements (#87)

* Reshape and reallocate - small fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Reshape and reallocate - small fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #6488 ElementWiseVertex broadcast support

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Constructors and broadcast supported in Transforms.max/min

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8054 ElementWiseVertex now supports broadcast inputs

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8057 Nd4j.create overload dtype fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7551 ND4J Shape validation fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-07-30 00:27:38 +10:00 committed by AlexDBlack
parent aa4af2c36d
commit b95417f7c5
13 changed files with 346 additions and 48 deletions

View File

@@ -16,6 +16,7 @@
package org.deeplearning4j.gradientcheck;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
@@ -53,9 +54,9 @@ import java.util.Arrays;
import java.util.Map;
import java.util.Random;
- import static org.junit.Assert.assertEquals;
- import static org.junit.Assert.assertTrue;
+ import static org.junit.Assert.*;
@Slf4j
public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
public static final boolean PRINT_RESULTS = true;
@@ -287,6 +288,56 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
}
}
@Test
public void testElementWiseVertexBroadcast(){
ElementWiseVertex.Op[] ops =
new ElementWiseVertex.Op[] {ElementWiseVertex.Op.Add, ElementWiseVertex.Op.Average,
ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Max, ElementWiseVertex.Op.Product};
for(boolean firstSmaller : new boolean[]{false, true}) {
for (ElementWiseVertex.Op op : ops) {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new NoOp())
.dataType(DataType.DOUBLE)
.activation(Activation.TANH)
.seed(12345)
.graphBuilder()
.addInputs("in")
.setOutputs("out")
.layer("l1", new DenseLayer.Builder().nIn(3).nOut(firstSmaller ? 1 : 3).build(), "in") //[mb,3]
.layer("l2", new DenseLayer.Builder().nIn(3).nOut(firstSmaller ? 3 : 1).build(), "in") //[mb,1]
.addVertex("ew", new ElementWiseVertex(op), "l1", "l2")
.layer("out", new OutputLayer.Builder().nIn(3).nOut(2).lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).build(), "ew")
.build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
for (int mb : new int[]{1, 5}) {
String msg = (firstSmaller ? "first smaller, " : "second smaller, ") + "mb=" + mb + ", op=" + op;
log.info("Test: {}", msg);
INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3);
INDArray out = graph.outputSingle(in);
assertArrayEquals(new long[]{mb, 2}, out.shape());
INDArray labels = TestUtils.randomOneHot(mb, 2);
graph.fit(new DataSet(in, labels));
boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in},
new INDArray[]{labels});
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
}
}
}
@Test
public void testCnnDepthMerge() {

View File

@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.conf.graph;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -34,6 +35,7 @@ import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
@@ -42,6 +44,8 @@ import org.nd4j.linalg.primitives.Pair;
import java.util.Map;
import static org.junit.Assert.assertArrayEquals;
/**
* Created by binesh on 6/14/2017.
*/
@@ -690,6 +694,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
}
private static double mse(INDArray output, INDArray target) {
double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue()
/ (output.columns() * output.rows());

View File

@@ -27,15 +27,20 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.util.Arrays;
/** An ElementWiseVertex is used to combine the activations of two or more layers in an element-wise manner<br>
* For example, the activations may be combined by addition, subtraction or multiplication or by selecting the maximum.
* Addition, Average, Product and Max may use an arbitrary number of input arrays. Note that in the case of subtraction, only two inputs may be used.
@@ -80,17 +85,44 @@ public class ElementWiseVertex extends BaseGraphVertex {
if (inputs.length == 1)
return workspaceMgr.dup(ArrayType.ACTIVATIONS, inputs[0]);
boolean isBc = false;
for(int i=1; i<inputs.length; i++ ){
if(!inputs[0].equalShapes(inputs[i])){
isBc = true;
break;
}
}
long[] outShape;
if(!isBc){
outShape = inputs[0].shape();
} else {
outShape = Shape.broadcastOutputShape(inputs[0].shape(), inputs[1].shape());
for( int i=2; i<inputs.length; i++ ){
outShape = Shape.broadcastOutputShape(outShape, inputs[i].shape());
}
}
switch (op) {
case Add:
- INDArray sum = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, inputs[0].shape());
+ INDArray sum = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, outShape);
if(isBc && !Arrays.equals(outShape, inputs[0].shape())){
Nd4j.exec(new BroadcastTo(inputs[0], outShape, sum));
} else {
sum.assign(inputs[0]);
}
for (int i = 1; i < inputs.length; i++) {
sum.addi(inputs[i].castTo(dataType));
}
return sum;
case Average:
- INDArray average = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, inputs[0].shape());
+ INDArray average = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, outShape);
if(isBc && !Arrays.equals(outShape, inputs[0].shape())){
Nd4j.exec(new BroadcastTo(inputs[0], outShape, average));
} else {
average.assign(inputs[0]);
}
for (int i = 1; i < inputs.length; i++) {
average.addi(inputs[i].castTo(dataType));
}
@@ -98,15 +130,28 @@
case Subtract:
if (inputs.length != 2)
throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs");
- return Nd4j.getExecutioner().exec(new OldSubOp(inputs[0], inputs[1], workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), inputs[0].shape())));
+ return Nd4j.exec(new SubOp(inputs, new INDArray[]{workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), outShape)}))[0];
case Product:
- INDArray product = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, inputs[0].shape());
+ INDArray product = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, outShape);
if(isBc && !Arrays.equals(outShape, inputs[0].shape())){
Nd4j.exec(new BroadcastTo(inputs[0], outShape, product));
} else {
product.assign(inputs[0]);
}
for (int i = 1; i < inputs.length; i++) {
product.muli(inputs[i].castTo(dataType));
}
return product;
case Max:
boolean isBroadcast = false;
for(int i=1; i<inputs.length; i++ ){
isBroadcast |= !inputs[0].equalShapes(inputs[i]);
if(isBroadcast)
break;
}
if(!isBroadcast) {
INDArray max = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), inputs[0].shape(), inputs[0].ordering());
CustomOp op = DynamicCustomOp.builder("mergemax")
.addInputs(inputs)
@@ -115,6 +160,19 @@
.build();
Nd4j.getExecutioner().exec(op);
return max;
} else {
//AB 20190729 mergemax doesn't support broadcast at this point
if(inputs.length == 1){
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, inputs[0]);
} else {
INDArray max = Transforms.max(inputs[0], inputs[1], true);
for( int i=2; i<inputs.length; i++ ){
max = Transforms.max(max, inputs[i], false);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, max);
}
}
default:
throw new UnsupportedOperationException("Unknown op: " + this.op);
}
@@ -128,40 +186,119 @@
if (nInForwardPass == 1)
return new Pair<>(null, new INDArray[] {workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon)});
boolean broadcastCase = false;
for( int i=1; i<nInForwardPass; i++ ){
broadcastCase |= !inputs[0].equalShapes(inputs[i]);
}
switch (op) {
case Add:
//If x=sum_i a_i then dL/da_i = dL/dx * dx/da_i = dL/dx
INDArray[] out = new INDArray[nInForwardPass];
- for (int i = 0; i < nInForwardPass; i++)
+ for (int i = 0; i < nInForwardPass; i++) {
if(!broadcastCase) {
//Standard case
out[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
} else {
//For broadcast case, we need to sum along the broadcast dimensions
//So if [mb,3]+[mb,1] -> input 0 backprops epsilon, input 1 backprops epsilon.sum(1,keepDim=true)
if(inputs[i].equalShapes(epsilon)){
out[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
} else {
int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
out[i] = epsilon.sum(true, bcDim);
}
}
}
}
return new Pair<>(null, out);
case Average:
INDArray[] outAverage = new INDArray[nInForwardPass];
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
- for (int i = 0; i < nInForwardPass; i++)
+ for (int i = 0; i < nInForwardPass; i++) {
if(inputs[i].equalShapes(epsilon)){
outAverage[i] = epsilon.div(nInForwardPass);
} else {
int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
outAverage[i] = epsilon.div(nInForwardPass).sum(true, bcDim);
}
}
}
return new Pair<>(null, outAverage);
case Subtract:
INDArray[] out2 = new INDArray[2];
if(!broadcastCase){
out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
} else {
if(inputs[0].equalShapes(epsilon)){
//Second input is smaller/broadcast
out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
int[] bcDim = Shape.getBroadcastDimensions(inputs[1].shape(), epsilon.shape());
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
out2[1] = epsilon.sum(true, bcDim).negi();
}
} else {
//First input is smaller/broadcast
int[] bcDim = Shape.getBroadcastDimensions(inputs[0].shape(), epsilon.shape());
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
out2[0] = epsilon.sum(true, bcDim);
}
out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
}
}
return new Pair<>(null, out2);
case Product:
INDArray[] out_product = new INDArray[nInForwardPass];
INDArray[] inBc = inputs;
if(broadcastCase){
inBc = new INDArray[inputs.length];
for( int i=0; i<inputs.length; i++ ){
if(inputs[i].equalShapes(epsilon)){
inBc[i] = inputs[i];
} else {
inBc[i] = epsilon.ulike();
Nd4j.exec(new BroadcastTo(inputs[i], epsilon.shape(), inBc[i]));
}
}
}
for (int i = 0; i < nInForwardPass; i++) {
out_product[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
for (int j = 0; j < nInForwardPass; ++j) {
if (i != j)
- out_product[i].muli(inputs[j]);
+ out_product[i].muli(inBc[j]);
}
if(!inputs[i].equalShapes(epsilon)){
int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
out_product[i] = out_product[i].sum(true, bcDim);
}
}
}
return new Pair<>(null, out_product);
case Max:
INDArray[] outMax = new INDArray[nInForwardPass];
INDArray maxIndices = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, DataType.INT, epsilon.shape(), epsilon.ordering());
INDArray[] bcIn = inputs;
if(broadcastCase){
//Broadcast to right shape...
bcIn = new INDArray[inputs.length];
for( int i=0; i<inputs.length; i++ ){
if(inputs[i].equalShapes(epsilon)){
bcIn[i] = inputs[i];
} else {
bcIn[i] = epsilon.ulike();
Nd4j.exec(new BroadcastTo(inputs[i], epsilon.shape(), bcIn[i]));
}
}
}
CustomOp op = DynamicCustomOp.builder("mergemaxindex")
- .addInputs(inputs)
+ .addInputs(bcIn)
.addOutputs(maxIndices)
.callInplace(false)
.build();
@@ -172,7 +309,17 @@
//generate a mask with 1s and 0s in the right places and muli with epsilon
MatchConditionTransform nd4jop = new MatchConditionTransform(maxIndices, outMax[i], Conditions.equals(i));
Nd4j.getExecutioner().exec(nd4jop);
- outMax[i] = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, outMax[i].castTo(Nd4j.defaultFloatingPointType())).muli(epsilon);
+ if(broadcastCase && !epsilon.equalShapes(inputs[i])){
+     //Broadcast for this input
+     outMax[i] = outMax[i].castTo(epsilon.dataType()).mul(epsilon);
+     int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
+     try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
+         outMax[i] = outMax[i].sum(true, bcDim);
+     }
+ } else {
+     //Standard case
+     outMax[i] = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, outMax[i].castTo(epsilon.dataType()).muli(epsilon));
+ }
}
return new Pair<>(null, outMax);
default:
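Note on the broadcast backprop rule used throughout the cases above: where a smaller input was broadcast up to the output shape in the forward pass, its gradient is the upstream epsilon summed over the broadcast dimensions, with keepDims so the result matches that input's shape. A minimal sketch using the same ND4J calls as the vertex (class and variable names here are illustrative only):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;

public class BroadcastGradientSketch {
    public static void main(String[] args) {
        INDArray smallInput = Nd4j.rand(5, 1);  // input that was broadcast, shape [mb,1]
        INDArray epsilon = Nd4j.rand(5, 3);     // upstream gradient at the output, shape [mb,3]
        // Dimensions along which smallInput was broadcast to reach epsilon's shape
        int[] bcDim = Shape.getBroadcastDimensions(smallInput.shape(), epsilon.shape());
        // For Add: dL/dsmallInput = epsilon summed over the broadcast dims, keeping dims
        INDArray grad = epsilon.sum(true, bcDim);
        System.out.println(java.util.Arrays.toString(grad.shape()));  // prints [5, 1]
    }
}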

View File

@@ -2476,6 +2476,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
// length/data.length can be different in case of Threshold conversion
if(isEmpty() || isS())
return false;
return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0
|| (length() < data().length() && data.dataType() != DataType.INT)
|| data().originalDataBuffer() != null;
@@ -4577,7 +4578,7 @@
return ret;
} else {
INDArray ret = this.dup(order);
- return ret.reshape(order, shape);
+ return Nd4j.create(ret.data(), shape);
}
}
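The reshape fallback above now copies the data into a fresh array via Nd4j.create(ret.data(), shape) rather than reshaping the duplicate again. A hedged sketch of the resulting contract, mirroring the isView assertion in the tests later in this commit:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class ReshapeCopySketch {
    public static void main(String[] args) {
        // A strided view: both rows of a 2x3 array, first two columns only
        INDArray view = Nd4j.create(new long[]{2, 3}, 'c')
                .get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2));
        // enforceView = false: reshape falls back to a copy instead of throwing
        INDArray reshaped = view.reshape('c', false, 4, 1);
        System.out.println(reshaped.isView());  // expected: false, i.e. a copy
    }
}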

View File

@@ -44,6 +44,10 @@
super(sameDiff, args, inPlace);
}
public Max( INDArray first, INDArray second, INDArray out){
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
}
public Max( INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}

View File

@@ -44,6 +44,10 @@
super(sameDiff, args, inPlace);
}
public Min( INDArray first, INDArray second, INDArray out){
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
}
public Min( INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}
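These new two-array constructors are what the reworked Transforms.max/min (later in this commit) execute; the same pattern works standalone. A hedged usage sketch (names illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.factory.Nd4j;

public class PairwiseMaxSketch {
    public static void main(String[] args) {
        INDArray first = Nd4j.rand(3, 4);
        INDArray second = Nd4j.rand(1, 4);                   // broadcastable against [3,4]
        INDArray out = Nd4j.create(first.dataType(), 3, 4);  // preallocated output
        // Elementwise maximum with broadcasting; Nd4j.exec returns the op's outputs
        INDArray result = Nd4j.exec(new Max(first, second, out))[0];
        System.out.println(java.util.Arrays.toString(result.shape()));  // prints [3, 4]
    }
}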

View File

@@ -26,7 +26,7 @@ import java.util.Collections;
import java.util.List;
/**
- * Calculate the absolute minimum over a vector
+ * Calculate the maximum value between two arrays in an elementwise fashion, broadcasting if required
*
* @author raver119@gmail.com
*/

View File

@@ -26,7 +26,7 @@ import java.util.Collections;
import java.util.List;
/**
- * Calculate the absolute minimum over a vector
+ * Calculate the minimum value between two arrays in an elementwise fashion, broadcasting if required
*
* @author raver119@gmail.com
*/

View File

@@ -538,7 +538,7 @@ public class Nd4j {
public static INDArray create(int[] sliceShape, float[]... arrays) {
//TODO: Remove duplicate code.
int slices = arrays.length;
- INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape));
+ INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
for (int i = 0; i < ret.slices(); i++)
ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape)));
return ret;
@@ -572,7 +572,7 @@
*/
public static INDArray create(int[] sliceShape, double[]... arrays) {
int slices = arrays.length;
- INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape));
+ INDArray ret = Nd4j.createUninitialized(DataType.DOUBLE, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
for (int i = 0; i < ret.slices(); i++)
ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape)));
return ret;
@@ -3984,6 +3984,7 @@
* @return the created ndarray.
*/
public static INDArray create(int[] data, long[] shape, DataType type) {
checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -3991,6 +3992,7 @@
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(long[] data, long[] shape, DataType type) {
checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -3998,6 +4000,7 @@
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(double[] data, long[] shape, DataType type) {
checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -4005,6 +4008,7 @@
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(float[] data, long[] shape, DataType type) {
checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -4012,6 +4016,7 @@
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(short[] data, long[] shape, DataType type) {
checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -4019,6 +4024,7 @@
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(byte[] data, long[] shape, DataType type) {
checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -4026,6 +4032,7 @@
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(boolean[] data, long[] shape, DataType type) {
checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -5165,17 +5172,17 @@
protected static void checkShapeValues(int length, int... shape) {
checkShapeValues(shape);
- if (ArrayUtil.prodLong(shape) > length)
+ if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0))
throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape)
+ " doesn't match data length: " + length);
+ " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided");
}
protected static void checkShapeValues(int length, long... shape) {
checkShapeValues(shape);
- if (ArrayUtil.prodLong(shape) > length)
+ if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0))
throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape)
+ " doesn't match data length: " + length);
+ " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided");
}
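Taken together, the changes in this file mean the slice-based create overloads keep the dtype of the supplied arrays, and prod(shape) must now exactly equal the data length rather than merely not exceed it. A hedged sketch of both behaviours (the tests added later in this commit verify the same):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CreateFixSketch {
    public static void main(String[] args) {
        // dtype now follows the slice arrays: float[] slices produce a FLOAT array
        float[] slice = {1f, 2f, 3f};
        INDArray x = Nd4j.create(new int[]{3}, slice, slice);  // shape [2,3]
        System.out.println(x.dataType());  // FLOAT
        // Stricter validation: 3 values cannot fill shape [1,1] (prod(shape) == 1)
        try {
            Nd4j.create(new double[]{1, 2, 3}, new long[]{1, 1}, DataType.DOUBLE);
        } catch (Exception e) {
            System.out.println(e.getMessage());  // reports the shape/length mismatch
        }
    }
}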

View File

@@ -45,9 +45,11 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
import org.nd4j.linalg.api.ops.impl.transforms.same.*;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.inverse.InvertMatrix;
import java.util.Arrays;
import java.util.List;
/**
@@ -858,11 +860,11 @@
* @return
*/
public static INDArray max(INDArray first, INDArray second, boolean dup) {
- INDArray result = first;
- if (dup) {
-     result = first.ulike();
- }
- return exec(new OldMax(first, second, result));
+ long[] outShape = broadcastResultShape(first, second); //Also validates
+ Preconditions.checkState(dup || Arrays.equals(outShape, first.shape()), "Cannot do inplace max operation when first input is not equal to result shape (%ndShape vs. result %s)",
+     first, outShape);
+ INDArray out = dup ? Nd4j.create(first.dataType(), outShape) : first;
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second, out))[0];
}
/**
@@ -908,10 +910,11 @@
* @return
*/
public static INDArray min(INDArray first, INDArray second, boolean dup) {
- if (dup) {
-     first = first.dup();
- }
- return exec(new OldMin(second, first, first));
+ long[] outShape = broadcastResultShape(first, second); //Also validates
+ Preconditions.checkState(dup || Arrays.equals(outShape, first.shape()), "Cannot do inplace min operation when first input is not equal to result shape (%ndShape vs. result %s)",
+     first, outShape);
+ INDArray out = dup ? Nd4j.create(first.dataType(), outShape) : first;
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second, out))[0];
}
/**
@@ -1179,4 +1182,15 @@
}
}
protected static long[] broadcastResultShape(INDArray first, INDArray second){
if(first.equalShapes(second)){
return first.shape();
} else if(Shape.areShapesBroadcastable(first.shape(), second.shape())){
return Shape.broadcastOutputShape(first.shape(), second.shape());
} else {
throw new IllegalStateException("Array shapes are not broadcastable: " + Arrays.toString(first.shape()) +
" vs. " + Arrays.toString(second.shape()));
}
}
}
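With broadcastResultShape in place, Transforms.max/min accept any pair of broadcastable shapes; dup=false remains valid only when the first array already has the result shape, otherwise the Preconditions check fires. A hedged usage sketch:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class TransformsBroadcastSketch {
    public static void main(String[] args) {
        INDArray a = Nd4j.rand(3, 4);
        INDArray b = Nd4j.rand(1, 4);  // broadcastable against [3,4]
        // dup = true: a fresh [3,4] result, both inputs untouched
        INDArray m = Transforms.max(a, b, true);
        System.out.println(java.util.Arrays.toString(m.shape()));  // prints [3, 4]
        // dup = false: writes into 'a'; valid here because a.shape() equals the result shape
        Transforms.min(a, b, false);
    }
}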

View File

@@ -2699,6 +2699,21 @@
assertEquals(expected, actual);
}
@Test
public void testBroadcastDiv2(){
INDArray arr = Nd4j.ones(DataType.DOUBLE, 1, 64, 125, 125).muli(2);
INDArray vec = Nd4j.ones(DataType.DOUBLE, 64).muli(2);
INDArray exp = Nd4j.ones(DataType.DOUBLE, 1, 64, 125, 125);
INDArray out = arr.like();
for( int i=0; i<10; i++ ) {
out.assign(0.0);
Nd4j.getExecutioner().exec(new BroadcastDivOp(arr, vec, out, 1));
assertEquals(exp, out);
}
}
@Test
public void testBroadcastMult() {
@@ -7417,7 +7432,8 @@
INDArray arr1a = Nd4j.create(new long[]{2,3}, 'c').get(NDArrayIndex.all(), NDArrayIndex.interval(0,2));
INDArray arr3 = arr1a.reshape('c', false, 4,1);
- assertFalse(arr3.isView()); //Should be copy
+ boolean isView = arr3.isView();
+ assertFalse(isView); //Should be copy
try{
INDArray arr4 = arr1a.reshape('c', true, 4,1);
@@ -7861,6 +7877,54 @@
final INDArray arr2 = arr1.reshape(3,1);
assertEquals("Incorrect type!", DataType.FLOAT, arr1.mmul(arr2).dataType());
}
@Test
public void testCreateDtypes() {
int[] sliceShape = new int[] {9};
float[] arrays = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
double [] arrays_double = new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
INDArray x = Nd4j.create( sliceShape, arrays, arrays );
assertEquals(DataType.FLOAT, x.dataType());
INDArray xd = Nd4j.create( sliceShape, arrays_double, arrays_double );
assertEquals(DataType.DOUBLE, xd.dataType());
}
@Test
public void testCreateShapeValidation(){
try {
Nd4j.create(new double[]{1, 2, 3}, new int[]{1, 1});
fail();
} catch (Exception t){
assertTrue(t.getMessage().contains("length"));
}
try {
Nd4j.create(new float[]{1, 2, 3}, new int[]{1, 1});
fail();
} catch (Exception t){
assertTrue(t.getMessage().contains("length"));
}
try {
Nd4j.create(new byte[]{1, 2, 3}, new long[]{1, 1}, DataType.BYTE);
fail();
} catch (Exception t){
assertTrue(t.getMessage().contains("length"));
}
try {
Nd4j.create(new double[]{1, 2, 3}, new int[]{1, 1}, 'c');
fail();
} catch (Exception t){
assertTrue(t.getMessage().contains("length"));
}
}
///////////////////////////////////////////////////////
protected static void fillJvmArray3D(float[][][] arr) {
int cnt = 1;

View File

@@ -2601,7 +2601,8 @@
}
Pointer.memcpy(pointer, oldPointer, this.length() * getElementSize());
- //this.underlyingLength = length;
+ this.underlyingLength = length;
+ this.length = length;
return this;
}
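The fix above makes reallocate record the new length instead of leaving the old value in place after the memcpy. A minimal hedged sketch of the expected behaviour:

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.factory.Nd4j;

public class ReallocateSketch {
    public static void main(String[] args) {
        DataBuffer buffer = Nd4j.createBuffer(new float[]{1f, 2f, 3f});
        // Grow the underlying storage; after this fix, length() reflects the new size
        DataBuffer grown = buffer.reallocate(6);
        System.out.println(grown.length());  // expected: 6
    }
}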