Various ND4J/DL4J fixes (#90)

* Deprecate Old*Op instances

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8063 #8054 Broadcast exceptions + cleanup inplace ops

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove bad test condition

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7993 Fix shape function issue in crop_and_resize op

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J SameDiff lambda layer fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8029 Fix for pnorm backprop math

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-01 19:30:58 +10:00 committed by AlexDBlack
parent fbe120031d
commit e18e2dc014
23 changed files with 188 additions and 188 deletions

View File

@ -290,18 +290,16 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
int nOut = 2;
//1 example, TS length 3
INDArray mask1 = Nd4j.create(new double[] {1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0}, new int[] {1, nOut, 3}, 'f');
INDArray mask1 = Nd4j.create(new double[] {1, 0, 0, 1, 0, 1}, new int[] {1, nOut, 3}, 'f');
//1 example, TS length 1
INDArray mask2 = Nd4j.create(new double[] {1, 1, 0, 1}, new int[] {1, nOut, 1}, 'f');
INDArray mask2 = Nd4j.create(new double[] {1, 1}, new int[] {1, nOut, 1}, 'f');
//3 examples, TS length 3
INDArray mask3 = Nd4j.create(new double[] {
//With fortran order: dimension 0 (example) changes quickest, followed by dimension 1 (value within time
// step) followed by time index (least frequently)
1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0,
0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1,
1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0}, new int[] {3, nOut, 3}, 'f');
1, 0, 1, 0, 1, 1,
0, 1, 1, 1, 1, 0,
1, 1, 1, 0, 0, 1,}, new int[] {3, nOut, 3}, 'f');
INDArray[] labelMasks = new INDArray[] {mask1, mask2, mask3};
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(),

View File

@ -127,7 +127,7 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation
}
if(context == null || true){
if(context == null ){
context = Nd4j.getExecutioner().buildContext();
context.setIArguments(kernel[0], kernel[1],
strides[0], strides[1],

View File

@ -159,7 +159,14 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
g.gradientForVariable().put(s, dl4jGrad);
}
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
SDVariable v = sameDiff.grad(INPUT_KEY);
dLdIn = v.getArr();
if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){
//Edge case with lambda layers like identity: SameDiff doesn't store the placeholders
// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
dLdIn = epsilon;
}
}
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere

View File

@ -56,7 +56,7 @@ namespace nd4j {
}
DECLARE_SHAPE_FN(crop_and_resize) {
auto in = inputShape->at(1);
auto in = inputShape->at(0);
Nd4jLong outputShape[4];

View File

@ -2014,7 +2014,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
pgI[kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0 - 1.f);
pgI[kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn<T,T>(pIn[kh + kw]);
}
}
}

View File

@ -16,22 +16,15 @@
package org.nd4j.autodiff.samediff;
import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput;
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import com.rits.cloning.Cloner;
import com.rits.cloning.IFastCloner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
@ -49,7 +42,6 @@ import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.graph.*;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -59,7 +51,6 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
@ -76,7 +67,6 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JException;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
@ -89,11 +79,11 @@ import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.linalg.util.ND4JFileUtils;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.GraphDef;
import java.io.*;
import java.lang.reflect.Method;
@ -101,10 +91,11 @@ import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import org.tensorflow.framework.GraphDef;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput;
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
/**
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.

View File

@ -3692,7 +3692,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public INDArray divi(INDArray other) {
validateNumericalArray("divi", false);
return divi(other, this);
}
@ -3706,30 +3705,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public INDArray divi(INDArray other, INDArray result) {
validateNumericalArray("divi", false);
if (other.isScalar()) {
return divi(other.getDouble(0), result);
}
if (isScalar()) {
return other.rdivi(getDouble(0), result);
}
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
Nd4j.exec(new DivOp(new INDArray[]{this, other}, new INDArray[]{result}));
return result;
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
Nd4j.getExecutioner().exec(new BroadcastDivOp(this,other,result,broadcastDimensions));
return result;
}
LinAlgExceptions.assertSameShape(other, result);
Nd4j.getExecutioner().exec(new OldDivOp(this, other, result));
Shape.assertBroadcastable("divi", this, other, result);
Nd4j.exec(new DivOp(this, other, result));
return result;
}
@ -3741,7 +3718,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public INDArray muli(INDArray other) {
validateNumericalArray("muli", false);
return muli(other, this);
}
@ -3755,29 +3731,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public INDArray muli(INDArray other, INDArray result) {
validateNumericalArray("muli", false);
if (other.isScalar()) {
return muli(other.getDouble(0), result);
}
if (isScalar()) {
return other.muli(getDouble(0), result);
}
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
Nd4j.exec(new MulOp(new INDArray[]{this, other}, new INDArray[]{result}));
return result;
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
Nd4j.getExecutioner().exec(new BroadcastMulOp(this,other,result,broadcastDimensions));
return result;
}
LinAlgExceptions.assertSameShape(other, result);
Nd4j.getExecutioner().exec(new OldMulOp(this, other, result));
Shape.assertBroadcastable("muli", this, other, result);
Nd4j.exec(new MulOp(this, other, result));
return result;
}
@ -3802,31 +3757,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public INDArray subi(INDArray other, INDArray result) {
validateNumericalArray("subi", false);
if (other.isScalar()) {
return subi(other.getDouble(0), result);
}
if (isScalar()) {
return other.rsubi(getDouble(0), result);
}
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
Nd4j.exec(new SubOp(new INDArray[]{this, other}, new INDArray[]{result}));
return result;
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
Nd4j.getExecutioner().exec(new BroadcastSubOp(this,other,result,broadcastDimensions));
return result;
}
LinAlgExceptions.assertSameShape(other, result);
Nd4j.getExecutioner().exec(new OldSubOp(this, other,result));
Shape.assertBroadcastable("subi", this, other, result);
Nd4j.exec(new SubOp(this, other, result));
return result;
}
@ -3851,33 +3783,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public INDArray addi(INDArray other, INDArray result) {
validateNumericalArray("addi", false);
if (other.isScalar()) {
return this.addi(other.getDouble(0), result);
}
if (isScalar()) {
return other.addi(getDouble(0), result);
}
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
Nd4j.exec(new AddOp(new INDArray[]{this, other}, new INDArray[]{result}));
Shape.assertBroadcastable("addi", this, other, result);
Nd4j.exec(new AddOp(this, other, result));
return result;
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
result = Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(),other.shape()));
Nd4j.getExecutioner().exec(new BroadcastAddOp(this,other,result,broadcastDimensions));
return result;
} else {
LinAlgExceptions.assertSameShape(this, other, result);
Nd4j.getExecutioner().exec(new OldAddOp(this, other, result));
return result;
}
}
/**
@ -3954,7 +3862,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public INDArray rdivi(INDArray other, INDArray result) {
validateNumericalArray("rdivi", false);
return other.divi(this, result);
Shape.assertBroadcastable("rdivi", this, other, result);
Nd4j.exec(new RDivOp(this, other, result));
return result;
}
/**
@ -4003,33 +3913,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public INDArray rsubi(INDArray other, INDArray result) {
validateNumericalArray("rsubi", false);
if (other.isScalar()) {
return this.rsubi(other.getDouble(0), result);
}
if (isScalar()) {
return other.rsubi(getDouble(0), result);
}
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
Nd4j.exec(new RSubOp(new INDArray[]{this, other}, new INDArray[]{result}));
Shape.assertBroadcastable("rsubi", this, other, result);
Nd4j.exec(new RSubOp(this, other, result));
return result;
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
result = Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(),other.shape()));
Nd4j.getExecutioner().exec(new BroadcastRSubOp(this,other,result,broadcastDimensions));
return result;
} else {
LinAlgExceptions.assertSameShape(this, other, result);
Nd4j.getExecutioner().exec(new OldRSubOp(this, other, result));
return result;
}
}
/**
@ -6796,6 +6682,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
throw new IllegalStateException("Cannot perform operation " + opName + " on empty array with datatype " + dataType());
}
@Override
public boolean closeable() {
if (released || isAttached())

View File

@ -38,6 +38,10 @@ public class AddOp extends BaseDynamicTransformOp {
super(sameDiff, args, inPlace);
}
public AddOp(INDArray first, INDArray second, INDArray result){
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
}
public AddOp(INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}

View File

@ -37,6 +37,10 @@ public class DivOp extends BaseDynamicTransformOp {
super(sameDiff, args, inPlace);
}
public DivOp(INDArray first, INDArray second, INDArray result){
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
}
public DivOp( INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}

View File

@ -37,6 +37,10 @@ public class MulOp extends BaseDynamicTransformOp {
super(sameDiff, args, inPlace);
}
public MulOp(INDArray first, INDArray second, INDArray result){
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
}
public MulOp( INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}

View File

@ -28,10 +28,9 @@ import java.util.ArrayList;
import java.util.List;
/**
* Add operation for two operands
*
* @author Adam Gibson
* @deprecated Use {@link AddOp}
*/
@Deprecated
public class OldAddOp extends BaseTransformAnyOp {
public OldAddOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);

View File

@ -28,10 +28,9 @@ import java.util.ArrayList;
import java.util.List;
/**
* Division operation
*
* @author Adam Gibson
* @deprecated Use {@link DivOp}
*/
@Deprecated
public class OldDivOp extends BaseTransformAnyOp {
public OldDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);

View File

@ -28,10 +28,9 @@ import java.util.ArrayList;
import java.util.List;
/**
* Multiplication operation
*
* @author Adam Gibson
* @deprecated Use {@link MulOp}
*/
@Deprecated
public class OldMulOp extends BaseTransformAnyOp {
public OldMulOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);

View File

@ -28,10 +28,9 @@ import java.util.ArrayList;
import java.util.List;
/**
* OldReverse Division operation
*
* @author Adam Gibson
* @deprecated Use {@link RDivOp}
*/
@Deprecated
public class OldRDivOp extends BaseTransformAnyOp {
public OldRDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);

View File

@ -26,10 +26,9 @@ import java.util.ArrayList;
import java.util.List;
/**
* Division operation
*
* @author Adam Gibson
* @deprecated Use {@link RSubOp}
*/
@Deprecated
public class OldRSubOp extends BaseTransformAnyOp {
public OldRSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);

View File

@ -28,10 +28,9 @@ import java.util.ArrayList;
import java.util.List;
/**
* Division operation
*
* @author Adam Gibson
* @deprecated Use {@link SubOp}
*/
@Deprecated
public class OldSubOp extends BaseTransformAnyOp {
public OldSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);

View File

@ -38,6 +38,10 @@ public class RDivOp extends BaseDynamicTransformOp {
super(sameDiff, args, inPlace);
}
public RDivOp(INDArray first, INDArray second, INDArray result){
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
}
public RDivOp( INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}

View File

@ -45,6 +45,10 @@ public class RSubOp extends BaseDynamicTransformOp {
this(sameDiff, new SDVariable[]{i_v1, i_v2}, inPlace);
}
public RSubOp(INDArray first, INDArray second, INDArray result){
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
}
public RSubOp() {}
@Override

View File

@ -37,6 +37,10 @@ public class SubOp extends BaseDynamicTransformOp {
super(sameDiff, args, inPlace);
}
public SubOp(INDArray first, INDArray second, INDArray result){
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
}
public SubOp( INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}

View File

@ -251,6 +251,41 @@ public class Shape {
return false;
}
/**
* Assert that the broadcast operation {@code result = first.op(second)} is valid, given the
* shapes of first, second, and result.<br>
* Throws an exception otherwise
*
* @param op Name of the operation
* @param first First array
* @param second Second array
* @param result Result arrray.
*/
public static void assertBroadcastable(String op, INDArray first, INDArray second, INDArray result){
long[] fShape = first.shape();
long[] sShape = second.shape();
Preconditions.checkState(Shape.areShapesBroadcastable(fShape, sShape),
"Cannot perform operation \"%s\" - shapes are not equal and are not broadcastable." +
"first.shape=%s, second.shape=%s", op, fShape, sShape);
long[] outShape = Shape.broadcastOutputShape(fShape, sShape);
if (!Arrays.equals(outShape, result.shape())) {
//Two cases
// 1. x.addi(y)
// 2. x.addi(y, z)
String extra = "";
if(first == result){
extra = ".\nIn-place operations like x." + op + "(y) can only be performed when x and y have the same shape," +
" or x and y are broadcastable with x.shape() == broadcastShape(x,y)";
}
throw new IllegalStateException("Cannot perform in-place operation \"" + op + "\": result array shape does" +
" not match the broadcast operation output shape: " + Arrays.toString(fShape) + "." + op + "(" +
Arrays.toString(sShape) + ") != " + Arrays.toString(result.shape()) + extra);
}
}
public static long[] broadcastOutputShape(long[] left,long[] right) {
if (containsZeros(left))
return left;

View File

@ -80,6 +80,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -6222,7 +6223,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(expected, b.rdiv(2));
assertEquals(expected2, d.rdivColumnVector(c));
assertEquals(expected, b.rdiv(Nd4j.scalar(2)));
assertEquals(expected, b.rdiv(Nd4j.scalar(2.0)));
assertEquals(expected, b.rdivColumnVector(Nd4j.scalar(2)));
}
@ -7958,7 +7959,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
c.addOutputArgument(out);
Nd4j.getExecutioner().exec(c);
assertEquals(Nd4j.createFromArray(1f, 3f, 4f), out);
List<LongShapeDescriptor> l = c.calculateOutputShape();
System.out.println(Arrays.toString(l.get(0).getShape()));
//from [4,4,3] to [2,4,6] then crop to [2,4,5]
}

View File

@ -30,6 +30,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* @author raver119@gmail.com
@ -122,42 +123,42 @@ public class BasicBroadcastTests extends BaseNd4jTest {
assertEquals(e, z);
}
@Test(expected = IllegalArgumentException.class)
@Test(expected = IllegalStateException.class)
public void basicBroadcastFailureTest_1() {
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
val z = x.subi(y);
}
@Test(expected = IllegalArgumentException.class)
@Test(expected = IllegalStateException.class)
public void basicBroadcastFailureTest_2() {
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
val z = x.divi(y);
}
@Test(expected = IllegalArgumentException.class)
@Test(expected = IllegalStateException.class)
public void basicBroadcastFailureTest_3() {
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
val z = x.muli(y);
}
@Test(expected = IllegalArgumentException.class)
@Test(expected = IllegalStateException.class)
public void basicBroadcastFailureTest_4() {
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
val z = x.addi(y);
}
@Test(expected = IllegalArgumentException.class)
@Test(expected = IllegalStateException.class)
public void basicBroadcastFailureTest_5() {
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
val z = x.rsubi(y);
}
@Test(expected = IllegalArgumentException.class)
@Test(expected = IllegalStateException.class)
public void basicBroadcastFailureTest_6() {
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
@ -206,7 +207,7 @@ public class BasicBroadcastTests extends BaseNd4jTest {
assertEquals(y, z);
}
@Test(expected = IllegalArgumentException.class)
@Test(expected = IllegalStateException.class)
public void emptyBroadcastTest_2() {
val x = Nd4j.create(DataType.FLOAT, 1, 2);
val y = Nd4j.create(DataType.FLOAT, 0, 2);
@ -226,6 +227,67 @@ public class BasicBroadcastTests extends BaseNd4jTest {
assertEquals(y, z);
}
@Test
public void testValidInvalidBroadcast(){
INDArray x = Nd4j.rand(3,1);
INDArray y = Nd4j.create(3, 4);
x.add(y);
y.addi(x);
try {
x.addi(y);
} catch (Exception e){
String s = e.getMessage();
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
}
x.sub(y);
y.subi(x);
try {
x.subi(y);
} catch (Exception e){
String s = e.getMessage();
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
}
x.mul(y);
y.muli(x);
try {
x.muli(y);
} catch (Exception e){
String s = e.getMessage();
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
}
x.div(y);
y.divi(x);
try {
x.divi(y);
} catch (Exception e){
String s = e.getMessage();
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
}
x.rsub(y);
y.rsubi(x);
try {
x.rsubi(y);
} catch (Exception e){
String s = e.getMessage();
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
}
x.rdiv(y);
y.rdivi(x);
try {
x.rdivi(y);
} catch (Exception e){
String s = e.getMessage();
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
}
}
@Override
public char ordering() {
return 'c';

View File

@ -7,8 +7,7 @@ import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
public class CropAndResizeDataSetPreProcessorTest {
@ -93,10 +92,7 @@ public class CropAndResizeDataSetPreProcessorTest {
// Assert
INDArray results = ds.getFeatures();
long[] shape = results.shape();
assertEquals(1, shape[0]);
assertEquals(4, shape[1]);
assertEquals(3, shape[2]);
assertEquals(3, shape[3]);
assertArrayEquals(new long[]{1, 4, 3, 3}, shape);
// Test a few values
assertEquals(55.0, results.getDouble(0, 0, 0, 0), 0.0);