Various ND4J/DL4J fixes and improvements (#87)
* Reshape and reallocate - small fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Reshape and reallocate - small fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6488 ElementWiseVertex broadcast support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Constructors and broadcast supported it Transforms.max/min Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8054 ElementWiseVertex now supports broadcast inputs Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8057 Nd4j.create overload dtype fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7551 ND4J Shape validation fix Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
aa4af2c36d
commit
b95417f7c5
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.gradientcheck;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.TestUtils;
|
||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||
|
@ -53,9 +54,9 @@ import java.util.Arrays;
|
|||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Slf4j
|
||||
public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
|
||||
|
||||
public static final boolean PRINT_RESULTS = true;
|
||||
|
@ -287,6 +288,56 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testElementWiseVertexBroadcast(){
|
||||
|
||||
ElementWiseVertex.Op[] ops =
|
||||
new ElementWiseVertex.Op[] {ElementWiseVertex.Op.Add, ElementWiseVertex.Op.Average,
|
||||
ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Max, ElementWiseVertex.Op.Product};
|
||||
|
||||
for(boolean firstSmaller : new boolean[]{false, true}) {
|
||||
for (ElementWiseVertex.Op op : ops) {
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.updater(new NoOp())
|
||||
.dataType(DataType.DOUBLE)
|
||||
.activation(Activation.TANH)
|
||||
.seed(12345)
|
||||
.graphBuilder()
|
||||
.addInputs("in")
|
||||
.setOutputs("out")
|
||||
.layer("l1", new DenseLayer.Builder().nIn(3).nOut(firstSmaller ? 1 : 3).build(), "in") //[mb,3]
|
||||
.layer("l2", new DenseLayer.Builder().nIn(3).nOut(firstSmaller ? 3 : 1).build(), "in") //[mb,1]
|
||||
.addVertex("ew", new ElementWiseVertex(op), "l1", "l2")
|
||||
.layer("out", new OutputLayer.Builder().nIn(3).nOut(2).lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).build(), "ew")
|
||||
.build();
|
||||
|
||||
ComputationGraph graph = new ComputationGraph(conf);
|
||||
graph.init();
|
||||
|
||||
for (int mb : new int[]{1, 5}) {
|
||||
String msg = (firstSmaller ? "first smaller, " : "second smaller, ") + "mb=" + mb + ", op=" + op;
|
||||
|
||||
log.info("Test: {}", msg);
|
||||
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3);
|
||||
|
||||
INDArray out = graph.outputSingle(in);
|
||||
assertArrayEquals(new long[]{mb, 2}, out.shape());
|
||||
|
||||
INDArray labels = TestUtils.randomOneHot(mb, 2);
|
||||
|
||||
graph.fit(new DataSet(in, labels));
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in},
|
||||
new INDArray[]{labels});
|
||||
assertTrue(msg, gradOK);
|
||||
TestUtils.testModelSerialization(graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCnnDepthMerge() {
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.nn.conf.graph;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.TestUtils;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
|
@ -34,6 +35,7 @@ import org.nd4j.linalg.activations.impl.ActivationTanH;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
||||
|
@ -42,6 +44,8 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
/**
|
||||
* Created by binesh on 6/14/2017.
|
||||
*/
|
||||
|
@ -690,6 +694,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
|
|||
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
|
||||
}
|
||||
|
||||
|
||||
private static double mse(INDArray output, INDArray target) {
|
||||
double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue()
|
||||
/ (output.columns() * output.rows());
|
||||
|
|
|
@ -27,15 +27,20 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOp;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/** An ElementWiseVertex is used to combine the activations of two or more layer in an element-wise manner<br>
|
||||
* For example, the activations may be combined by addition, subtraction or multiplication or by selecting the maximum.
|
||||
* Addition, Average, Product and Max may use an arbitrary number of input arrays. Note that in the case of subtraction, only two inputs may be used.
|
||||
|
@ -80,17 +85,44 @@ public class ElementWiseVertex extends BaseGraphVertex {
|
|||
if (inputs.length == 1)
|
||||
return workspaceMgr.dup(ArrayType.ACTIVATIONS, inputs[0]);
|
||||
|
||||
boolean isBc = false;
|
||||
for(int i=1; i<inputs.length; i++ ){
|
||||
if(!inputs[0].equalShapes(inputs[i])){
|
||||
isBc = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
long[] outShape;
|
||||
if(!isBc){
|
||||
outShape = inputs[0].shape();
|
||||
} else {
|
||||
outShape = Shape.broadcastOutputShape(inputs[0].shape(), inputs[1].shape());
|
||||
for( int i=2; i<inputs.length; i++ ){
|
||||
outShape = Shape.broadcastOutputShape(outShape, inputs[i].shape());
|
||||
}
|
||||
}
|
||||
|
||||
switch (op) {
|
||||
case Add:
|
||||
INDArray sum = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, inputs[0].shape());
|
||||
INDArray sum = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, outShape);
|
||||
if(isBc && !Arrays.equals(outShape, inputs[0].shape())){
|
||||
Nd4j.exec(new BroadcastTo(inputs[0], outShape, sum));
|
||||
} else {
|
||||
sum.assign(inputs[0]);
|
||||
}
|
||||
|
||||
for (int i = 1; i < inputs.length; i++) {
|
||||
sum.addi(inputs[i].castTo(dataType));
|
||||
}
|
||||
return sum;
|
||||
case Average:
|
||||
INDArray average = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, inputs[0].shape());
|
||||
INDArray average = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, outShape);
|
||||
if(isBc && !Arrays.equals(outShape, inputs[0].shape())){
|
||||
Nd4j.exec(new BroadcastTo(inputs[0], outShape, average));
|
||||
} else {
|
||||
average.assign(inputs[0]);
|
||||
}
|
||||
for (int i = 1; i < inputs.length; i++) {
|
||||
average.addi(inputs[i].castTo(dataType));
|
||||
}
|
||||
|
@ -98,15 +130,28 @@ public class ElementWiseVertex extends BaseGraphVertex {
|
|||
case Subtract:
|
||||
if (inputs.length != 2)
|
||||
throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs");
|
||||
return Nd4j.getExecutioner().exec(new OldSubOp(inputs[0], inputs[1], workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), inputs[0].shape())));
|
||||
return Nd4j.exec(new SubOp(inputs, new INDArray[]{workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), outShape)}))[0];
|
||||
case Product:
|
||||
INDArray product = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, inputs[0].shape());
|
||||
INDArray product = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, outShape);
|
||||
|
||||
if(isBc && !Arrays.equals(outShape, inputs[0].shape())){
|
||||
Nd4j.exec(new BroadcastTo(inputs[0], outShape, product));
|
||||
} else {
|
||||
product.assign(inputs[0]);
|
||||
}
|
||||
|
||||
for (int i = 1; i < inputs.length; i++) {
|
||||
product.muli(inputs[i].castTo(dataType));
|
||||
}
|
||||
return product;
|
||||
case Max:
|
||||
boolean isBroadcast = false;
|
||||
for(int i=1; i<inputs.length; i++ ){
|
||||
isBroadcast |= !inputs[0].equalShapes(inputs[i]);
|
||||
if(isBroadcast)
|
||||
break;
|
||||
}
|
||||
if(!isBroadcast) {
|
||||
INDArray max = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), inputs[0].shape(), inputs[0].ordering());
|
||||
CustomOp op = DynamicCustomOp.builder("mergemax")
|
||||
.addInputs(inputs)
|
||||
|
@ -115,6 +160,19 @@ public class ElementWiseVertex extends BaseGraphVertex {
|
|||
.build();
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
return max;
|
||||
} else {
|
||||
//AB 20190729 mergemax doesn't support broadcast at this point
|
||||
if(inputs.length == 1){
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, inputs[0]);
|
||||
} else {
|
||||
INDArray max = Transforms.max(inputs[0], inputs[1], true);
|
||||
for( int i=2; i<inputs.length; i++ ){
|
||||
max = Transforms.max(max, inputs[i], false);
|
||||
}
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, max);
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unknown op: " + this.op);
|
||||
}
|
||||
|
@ -128,40 +186,119 @@ public class ElementWiseVertex extends BaseGraphVertex {
|
|||
if (nInForwardPass == 1)
|
||||
return new Pair<>(null, new INDArray[] {workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon)});
|
||||
|
||||
boolean broadcastCase = false;
|
||||
for( int i=1; i<nInForwardPass; i++ ){
|
||||
broadcastCase |= !inputs[0].equalShapes(inputs[i]);
|
||||
}
|
||||
|
||||
switch (op) {
|
||||
case Add:
|
||||
//If x=sum_i a_i then dL/da_i = dL/dx * dx/da_i = dL/dx
|
||||
INDArray[] out = new INDArray[nInForwardPass];
|
||||
for (int i = 0; i < nInForwardPass; i++)
|
||||
for (int i = 0; i < nInForwardPass; i++) {
|
||||
if(!broadcastCase) {
|
||||
//Standard case
|
||||
out[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
|
||||
} else {
|
||||
//For broadcast case, we need to sum along the broadcast dimensions
|
||||
//So if [mb,3]+[mb,1] -> input 0 backprops epsilon, input 1 backprops epsilon.sum(1,keepDim=true)
|
||||
if(inputs[i].equalShapes(epsilon)){
|
||||
out[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
|
||||
} else {
|
||||
int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
|
||||
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
|
||||
out[i] = epsilon.sum(true, bcDim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return new Pair<>(null, out);
|
||||
case Average:
|
||||
INDArray[] outAverage = new INDArray[nInForwardPass];
|
||||
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
|
||||
for (int i = 0; i < nInForwardPass; i++)
|
||||
for (int i = 0; i < nInForwardPass; i++) {
|
||||
if(inputs[i].equalShapes(epsilon)){
|
||||
outAverage[i] = epsilon.div(nInForwardPass);
|
||||
} else {
|
||||
int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
|
||||
outAverage[i] = epsilon.div(nInForwardPass).sum(true, bcDim);
|
||||
}
|
||||
}
|
||||
}
|
||||
return new Pair<>(null, outAverage);
|
||||
case Subtract:
|
||||
INDArray[] out2 = new INDArray[2];
|
||||
if(!broadcastCase){
|
||||
out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
|
||||
out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
|
||||
} else {
|
||||
if(inputs[0].equalShapes(epsilon)){
|
||||
//Second input is smaller/broadcast
|
||||
out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
|
||||
int[] bcDim = Shape.getBroadcastDimensions(inputs[1].shape(), epsilon.shape());
|
||||
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
|
||||
out2[1] = epsilon.sum(true, bcDim).negi();
|
||||
}
|
||||
} else {
|
||||
//First input is smaller/broadcast
|
||||
int[] bcDim = Shape.getBroadcastDimensions(inputs[0].shape(), epsilon.shape());
|
||||
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
|
||||
out2[0] = epsilon.sum(true, bcDim);
|
||||
}
|
||||
out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
|
||||
}
|
||||
}
|
||||
return new Pair<>(null, out2);
|
||||
case Product:
|
||||
INDArray[] out_product = new INDArray[nInForwardPass];
|
||||
INDArray[] inBc = inputs;
|
||||
if(broadcastCase){
|
||||
inBc = new INDArray[inputs.length];
|
||||
for( int i=0; i<inputs.length; i++ ){
|
||||
if(inputs[i].equalShapes(epsilon)){
|
||||
inBc[i] = inputs[i];
|
||||
} else {
|
||||
inBc[i] = epsilon.ulike();
|
||||
Nd4j.exec(new BroadcastTo(inputs[i], epsilon.shape(), inBc[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < nInForwardPass; i++) {
|
||||
out_product[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
|
||||
for (int j = 0; j < nInForwardPass; ++j) {
|
||||
if (i != j)
|
||||
out_product[i].muli(inputs[j]);
|
||||
out_product[i].muli(inBc[j]);
|
||||
}
|
||||
|
||||
if(!inputs[i].equalShapes(epsilon)){
|
||||
int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
|
||||
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
|
||||
out_product[i] = out_product[i].sum(true, bcDim);
|
||||
}
|
||||
}
|
||||
}
|
||||
return new Pair<>(null, out_product);
|
||||
case Max:
|
||||
INDArray[] outMax = new INDArray[nInForwardPass];
|
||||
INDArray maxIndices = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, DataType.INT, epsilon.shape(), epsilon.ordering());
|
||||
|
||||
INDArray[] bcIn = inputs;
|
||||
if(broadcastCase){
|
||||
//Broadcast to right shape...
|
||||
bcIn = new INDArray[inputs.length];
|
||||
for( int i=0; i<inputs.length; i++ ){
|
||||
if(inputs[i].equalShapes(epsilon)){
|
||||
bcIn[i] = inputs[i];
|
||||
} else {
|
||||
bcIn[i] = epsilon.ulike();
|
||||
Nd4j.exec(new BroadcastTo(inputs[i], epsilon.shape(), bcIn[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CustomOp op = DynamicCustomOp.builder("mergemaxindex")
|
||||
.addInputs(inputs)
|
||||
.addInputs(bcIn)
|
||||
.addOutputs(maxIndices)
|
||||
.callInplace(false)
|
||||
.build();
|
||||
|
@ -172,7 +309,17 @@ public class ElementWiseVertex extends BaseGraphVertex {
|
|||
//generate a mask with 1s and 0s in the right places and muli with epsilon
|
||||
MatchConditionTransform nd4jop = new MatchConditionTransform(maxIndices, outMax[i], Conditions.equals(i));
|
||||
Nd4j.getExecutioner().exec(nd4jop);
|
||||
outMax[i] = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, outMax[i].castTo(Nd4j.defaultFloatingPointType())).muli(epsilon);
|
||||
if(broadcastCase && !epsilon.equalShapes(inputs[i])){
|
||||
//Broadcast for ths input
|
||||
outMax[i] = outMax[i].castTo(epsilon.dataType()).mul(epsilon);
|
||||
int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
|
||||
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
|
||||
outMax[i] = outMax[i].sum(true, bcDim);
|
||||
}
|
||||
} else {
|
||||
//Standard case
|
||||
outMax[i] = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, outMax[i].castTo(epsilon.dataType()).muli(epsilon));
|
||||
}
|
||||
}
|
||||
return new Pair<>(null, outMax);
|
||||
default:
|
||||
|
|
|
@ -2476,6 +2476,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
// length/data.length can be different in case of Threshold conversion
|
||||
if(isEmpty() || isS())
|
||||
return false;
|
||||
|
||||
return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0
|
||||
|| (length() < data().length() && data.dataType() != DataType.INT)
|
||||
|| data().originalDataBuffer() != null;
|
||||
|
@ -4577,7 +4578,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return ret;
|
||||
} else {
|
||||
INDArray ret = this.dup(order);
|
||||
return ret.reshape(order, shape);
|
||||
return Nd4j.create(ret.data(), shape);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -44,6 +44,10 @@ public class Max extends BaseDynamicTransformOp {
|
|||
super(sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public Max( INDArray first, INDArray second, INDArray out){
|
||||
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
|
||||
}
|
||||
|
||||
public Max( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
|
|
@ -44,6 +44,10 @@ public class Min extends BaseDynamicTransformOp {
|
|||
super(sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public Min( INDArray first, INDArray second, INDArray out){
|
||||
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
|
||||
}
|
||||
|
||||
public Min( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Calculate the absolute minimum over a vector
|
||||
* Calculate the maximum value between two arrays in an elementwise fashion, broadcasting if required
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
|
|
|
@ -26,7 +26,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Calculate the absolute minimum over a vector
|
||||
* Calculate the minimum value between two arrays in an elementwise fashion, broadcasting if required
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
|
|
|
@ -538,7 +538,7 @@ public class Nd4j {
|
|||
public static INDArray create(int[] sliceShape, float[]... arrays) {
|
||||
//TODO: Remove duplicate code.
|
||||
int slices = arrays.length;
|
||||
INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape));
|
||||
INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
|
||||
for (int i = 0; i < ret.slices(); i++)
|
||||
ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape)));
|
||||
return ret;
|
||||
|
@ -572,7 +572,7 @@ public class Nd4j {
|
|||
*/
|
||||
public static INDArray create(int[] sliceShape, double[]... arrays) {
|
||||
int slices = arrays.length;
|
||||
INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape));
|
||||
INDArray ret = Nd4j.createUninitialized(DataType.DOUBLE, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
|
||||
for (int i = 0; i < ret.slices(); i++)
|
||||
ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape)));
|
||||
return ret;
|
||||
|
@ -3984,6 +3984,7 @@ public class Nd4j {
|
|||
* @return the created ndarray.
|
||||
*/
|
||||
public static INDArray create(int[] data, long[] shape, DataType type) {
|
||||
checkShapeValues(data.length, shape);
|
||||
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
|
@ -3991,6 +3992,7 @@ public class Nd4j {
|
|||
* See {@link #create(int[], long[], DataType)}
|
||||
*/
|
||||
public static INDArray create(long[] data, long[] shape, DataType type) {
|
||||
checkShapeValues(data.length, shape);
|
||||
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
|
@ -3998,6 +4000,7 @@ public class Nd4j {
|
|||
* See {@link #create(int[], long[], DataType)}
|
||||
*/
|
||||
public static INDArray create(double[] data, long[] shape, DataType type) {
|
||||
checkShapeValues(data.length, shape);
|
||||
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
|
@ -4005,6 +4008,7 @@ public class Nd4j {
|
|||
* See {@link #create(int[], long[], DataType)}
|
||||
*/
|
||||
public static INDArray create(float[] data, long[] shape, DataType type) {
|
||||
checkShapeValues(data.length, shape);
|
||||
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
|
@ -4012,6 +4016,7 @@ public class Nd4j {
|
|||
* See {@link #create(int[], long[], DataType)}
|
||||
*/
|
||||
public static INDArray create(short[] data, long[] shape, DataType type) {
|
||||
checkShapeValues(data.length, shape);
|
||||
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
|
@ -4019,6 +4024,7 @@ public class Nd4j {
|
|||
* See {@link #create(int[], long[], DataType)}
|
||||
*/
|
||||
public static INDArray create(byte[] data, long[] shape, DataType type) {
|
||||
checkShapeValues(data.length, shape);
|
||||
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
|
@ -4026,6 +4032,7 @@ public class Nd4j {
|
|||
* See {@link #create(int[], long[], DataType)}
|
||||
*/
|
||||
public static INDArray create(boolean[] data, long[] shape, DataType type) {
|
||||
checkShapeValues(data.length, shape);
|
||||
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
|
@ -5165,17 +5172,17 @@ public class Nd4j {
|
|||
protected static void checkShapeValues(int length, int... shape) {
|
||||
checkShapeValues(shape);
|
||||
|
||||
if (ArrayUtil.prodLong(shape) > length)
|
||||
if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0))
|
||||
throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape)
|
||||
+ " doesn't match data length: " + length);
|
||||
+ " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided");
|
||||
}
|
||||
|
||||
protected static void checkShapeValues(int length, long... shape) {
|
||||
checkShapeValues(shape);
|
||||
|
||||
if (ArrayUtil.prodLong(shape) > length)
|
||||
if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0))
|
||||
throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape)
|
||||
+ " doesn't match data length: " + length);
|
||||
+ " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided");
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -45,9 +45,11 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
|
|||
import org.nd4j.linalg.api.ops.impl.transforms.same.*;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.inverse.InvertMatrix;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
|
@ -858,11 +860,11 @@ public class Transforms {
|
|||
* @return
|
||||
*/
|
||||
public static INDArray max(INDArray first, INDArray second, boolean dup) {
|
||||
INDArray result = first;
|
||||
if (dup) {
|
||||
result = first.ulike();
|
||||
}
|
||||
return exec(new OldMax(first, second, result));
|
||||
long[] outShape = broadcastResultShape(first, second); //Also validates
|
||||
Preconditions.checkState(dup || Arrays.equals(outShape, first.shape()), "Cannot do inplace max operation when first input is not equal to result shape (%ndShape vs. result %s)",
|
||||
first, outShape);
|
||||
INDArray out = dup ? Nd4j.create(first.dataType(), outShape) : first;
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second, out))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -908,10 +910,11 @@ public class Transforms {
|
|||
* @return
|
||||
*/
|
||||
public static INDArray min(INDArray first, INDArray second, boolean dup) {
|
||||
if (dup) {
|
||||
first = first.dup();
|
||||
}
|
||||
return exec(new OldMin(second, first, first));
|
||||
long[] outShape = broadcastResultShape(first, second); //Also validates
|
||||
Preconditions.checkState(dup || Arrays.equals(outShape, first.shape()), "Cannot do inplace min operation when first input is not equal to result shape (%ndShape vs. result %s)",
|
||||
first, outShape);
|
||||
INDArray out = dup ? Nd4j.create(first.dataType(), outShape) : first;
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second, out))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1179,4 +1182,15 @@ public class Transforms {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
protected static long[] broadcastResultShape(INDArray first, INDArray second){
|
||||
if(first.equalShapes(second)){
|
||||
return first.shape();
|
||||
} else if(Shape.areShapesBroadcastable(first.shape(), second.shape())){
|
||||
return Shape.broadcastOutputShape(first.shape(), second.shape());
|
||||
} else {
|
||||
throw new IllegalStateException("Array shapes are not broadcastable: " + Arrays.toString(first.shape()) +
|
||||
" vs. " + Arrays.toString(second.shape()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2699,6 +2699,21 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
assertEquals(expected, actual);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBroadcastDiv2(){
|
||||
INDArray arr = Nd4j.ones(DataType.DOUBLE, 1, 64, 125, 125).muli(2);
|
||||
INDArray vec = Nd4j.ones(DataType.DOUBLE, 64).muli(2);
|
||||
|
||||
INDArray exp = Nd4j.ones(DataType.DOUBLE, 1, 64, 125, 125);
|
||||
INDArray out = arr.like();
|
||||
|
||||
for( int i=0; i<10; i++ ) {
|
||||
out.assign(0.0);
|
||||
Nd4j.getExecutioner().exec(new BroadcastDivOp(arr, vec, out, 1));
|
||||
assertEquals(exp, out);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testBroadcastMult() {
|
||||
|
@ -7417,7 +7432,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
INDArray arr1a = Nd4j.create(new long[]{2,3}, 'c').get(NDArrayIndex.all(), NDArrayIndex.interval(0,2));
|
||||
INDArray arr3 = arr1a.reshape('c', false, 4,1);
|
||||
assertFalse(arr3.isView()); //Should be copy
|
||||
boolean isView = arr3.isView();
|
||||
assertFalse(isView); //Should be copy
|
||||
|
||||
try{
|
||||
INDArray arr4 = arr1a.reshape('c', true, 4,1);
|
||||
|
@ -7861,6 +7877,54 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
final INDArray arr2 = arr1.reshape(3,1);
|
||||
assertEquals("Incorrect type!", DataType.FLOAT, arr1.mmul(arr2).dataType());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testCreateDtypes() {
|
||||
int[] sliceShape = new int[] {9};
|
||||
float[] arrays = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
|
||||
double [] arrays_double = new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
|
||||
|
||||
INDArray x = Nd4j.create( sliceShape, arrays, arrays );
|
||||
assertEquals(DataType.FLOAT, x.dataType());
|
||||
|
||||
INDArray xd = Nd4j.create( sliceShape, arrays_double, arrays_double );
|
||||
assertEquals(DataType.DOUBLE, xd.dataType());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testCreateShapeValidation(){
|
||||
try {
|
||||
Nd4j.create(new double[]{1, 2, 3}, new int[]{1, 1});
|
||||
fail();
|
||||
} catch (Exception t){
|
||||
assertTrue(t.getMessage().contains("length"));
|
||||
}
|
||||
|
||||
try {
|
||||
Nd4j.create(new float[]{1, 2, 3}, new int[]{1, 1});
|
||||
fail();
|
||||
} catch (Exception t){
|
||||
assertTrue(t.getMessage().contains("length"));
|
||||
}
|
||||
|
||||
try {
|
||||
Nd4j.create(new byte[]{1, 2, 3}, new long[]{1, 1}, DataType.BYTE);
|
||||
fail();
|
||||
} catch (Exception t){
|
||||
assertTrue(t.getMessage().contains("length"));
|
||||
}
|
||||
|
||||
try {
|
||||
Nd4j.create(new double[]{1, 2, 3}, new int[]{1, 1}, 'c');
|
||||
fail();
|
||||
} catch (Exception t){
|
||||
assertTrue(t.getMessage().contains("length"));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////
|
||||
protected static void fillJvmArray3D(float[][][] arr) {
|
||||
int cnt = 1;
|
||||
|
|
|
@ -2601,7 +2601,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
}
|
||||
|
||||
Pointer.memcpy(pointer, oldPointer, this.length() * getElementSize());
|
||||
//this.underlyingLength = length;
|
||||
this.underlyingLength = length;
|
||||
this.length = length;
|
||||
return this;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue