Various ND4J/DL4J fixes (#90)
* Deprecate Old*Op instances Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8063 #8054 Broadcast exceptions + cleanup inplace ops Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove bad test condition Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7993 Fix shape function issue in crop_and_resize op Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J SameDiff lambda layer fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8029 Fix for pnorm backprop math Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
fbe120031d
commit
e18e2dc014
|
@ -290,18 +290,16 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
|
|||
int nOut = 2;
|
||||
|
||||
//1 example, TS length 3
|
||||
INDArray mask1 = Nd4j.create(new double[] {1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0}, new int[] {1, nOut, 3}, 'f');
|
||||
INDArray mask1 = Nd4j.create(new double[] {1, 0, 0, 1, 0, 1}, new int[] {1, nOut, 3}, 'f');
|
||||
//1 example, TS length 1
|
||||
INDArray mask2 = Nd4j.create(new double[] {1, 1, 0, 1}, new int[] {1, nOut, 1}, 'f');
|
||||
INDArray mask2 = Nd4j.create(new double[] {1, 1}, new int[] {1, nOut, 1}, 'f');
|
||||
//3 examples, TS length 3
|
||||
INDArray mask3 = Nd4j.create(new double[] {
|
||||
//With fortran order: dimension 0 (example) changes quickest, followed by dimension 1 (value within time
|
||||
// step) followed by time index (least frequently)
|
||||
1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0,
|
||||
|
||||
0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1,
|
||||
|
||||
1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0}, new int[] {3, nOut, 3}, 'f');
|
||||
1, 0, 1, 0, 1, 1,
|
||||
0, 1, 1, 1, 1, 0,
|
||||
1, 1, 1, 0, 0, 1,}, new int[] {3, nOut, 3}, 'f');
|
||||
INDArray[] labelMasks = new INDArray[] {mask1, mask2, mask3};
|
||||
|
||||
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(),
|
||||
|
|
|
@ -127,7 +127,7 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
|
|||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation
|
||||
}
|
||||
|
||||
if(context == null || true){
|
||||
if(context == null ){
|
||||
context = Nd4j.getExecutioner().buildContext();
|
||||
context.setIArguments(kernel[0], kernel[1],
|
||||
strides[0], strides[1],
|
||||
|
|
|
@ -159,7 +159,14 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
|||
g.gradientForVariable().put(s, dl4jGrad);
|
||||
}
|
||||
|
||||
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
|
||||
SDVariable v = sameDiff.grad(INPUT_KEY);
|
||||
dLdIn = v.getArr();
|
||||
|
||||
if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){
|
||||
//Edge case with lambda layers like identity: SameDiff doesn't store the placeholders
|
||||
// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
|
||||
dLdIn = epsilon;
|
||||
}
|
||||
}
|
||||
|
||||
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||
|
|
|
@ -56,7 +56,7 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
DECLARE_SHAPE_FN(crop_and_resize) {
|
||||
auto in = inputShape->at(1);
|
||||
auto in = inputShape->at(0);
|
||||
|
||||
Nd4jLong outputShape[4];
|
||||
|
||||
|
|
|
@ -2014,7 +2014,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
|||
|
||||
for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
|
||||
for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
|
||||
pgI[kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0 - 1.f);
|
||||
pgI[kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn<T,T>(pIn[kh + kw]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,22 +16,15 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput;
|
||||
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
|
||||
|
||||
import com.google.common.collect.HashBasedTable;
|
||||
import com.google.common.collect.Table;
|
||||
import com.google.common.primitives.Ints;
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.rits.cloning.Cloner;
|
||||
import com.rits.cloning.IFastCloner;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.io.output.CloseShieldOutputStream;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
|
||||
import org.nd4j.autodiff.execution.conf.OutputMode;
|
||||
|
@ -49,7 +42,6 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.graph.*;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
|
@ -59,7 +51,6 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
|||
import org.nd4j.linalg.api.ops.impl.controlflow.If;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.While;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
|
||||
|
@ -76,7 +67,6 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
|
|||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
import org.nd4j.linalg.exception.ND4JException;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
|
||||
|
@ -89,11 +79,11 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
||||
import org.nd4j.linalg.util.ND4JFileUtils;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import org.nd4j.weightinit.WeightInitScheme;
|
||||
import org.nd4j.weightinit.impl.ConstantInitScheme;
|
||||
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
|
||||
import org.nd4j.weightinit.impl.ZeroInitScheme;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Method;
|
||||
|
@ -101,10 +91,11 @@ import java.nio.ByteBuffer;
|
|||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipFile;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput;
|
||||
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
|
||||
|
||||
/**
|
||||
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
|
||||
|
|
|
@ -3692,7 +3692,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
*/
|
||||
@Override
|
||||
public INDArray divi(INDArray other) {
|
||||
validateNumericalArray("divi", false);
|
||||
return divi(other, this);
|
||||
}
|
||||
|
||||
|
@ -3706,30 +3705,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
@Override
|
||||
public INDArray divi(INDArray other, INDArray result) {
|
||||
validateNumericalArray("divi", false);
|
||||
if (other.isScalar()) {
|
||||
return divi(other.getDouble(0), result);
|
||||
}
|
||||
|
||||
if (isScalar()) {
|
||||
return other.rdivi(getDouble(0), result);
|
||||
}
|
||||
|
||||
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
|
||||
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
|
||||
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
|
||||
|
||||
Nd4j.exec(new DivOp(new INDArray[]{this, other}, new INDArray[]{result}));
|
||||
|
||||
return result;
|
||||
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
|
||||
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
|
||||
Nd4j.getExecutioner().exec(new BroadcastDivOp(this,other,result,broadcastDimensions));
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
LinAlgExceptions.assertSameShape(other, result);
|
||||
Nd4j.getExecutioner().exec(new OldDivOp(this, other, result));
|
||||
Shape.assertBroadcastable("divi", this, other, result);
|
||||
Nd4j.exec(new DivOp(this, other, result));
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -3741,7 +3718,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
*/
|
||||
@Override
|
||||
public INDArray muli(INDArray other) {
|
||||
validateNumericalArray("muli", false);
|
||||
return muli(other, this);
|
||||
}
|
||||
|
||||
|
@ -3755,29 +3731,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
@Override
|
||||
public INDArray muli(INDArray other, INDArray result) {
|
||||
validateNumericalArray("muli", false);
|
||||
if (other.isScalar()) {
|
||||
return muli(other.getDouble(0), result);
|
||||
}
|
||||
if (isScalar()) {
|
||||
return other.muli(getDouble(0), result);
|
||||
}
|
||||
|
||||
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
|
||||
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
|
||||
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
|
||||
|
||||
Nd4j.exec(new MulOp(new INDArray[]{this, other}, new INDArray[]{result}));
|
||||
|
||||
return result;
|
||||
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
|
||||
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
|
||||
Nd4j.getExecutioner().exec(new BroadcastMulOp(this,other,result,broadcastDimensions));
|
||||
return result;
|
||||
}
|
||||
|
||||
LinAlgExceptions.assertSameShape(other, result);
|
||||
|
||||
Nd4j.getExecutioner().exec(new OldMulOp(this, other, result));
|
||||
Shape.assertBroadcastable("muli", this, other, result);
|
||||
Nd4j.exec(new MulOp(this, other, result));
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -3802,31 +3757,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
@Override
|
||||
public INDArray subi(INDArray other, INDArray result) {
|
||||
validateNumericalArray("subi", false);
|
||||
if (other.isScalar()) {
|
||||
return subi(other.getDouble(0), result);
|
||||
}
|
||||
if (isScalar()) {
|
||||
return other.rsubi(getDouble(0), result);
|
||||
}
|
||||
|
||||
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
|
||||
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
|
||||
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
|
||||
|
||||
Nd4j.exec(new SubOp(new INDArray[]{this, other}, new INDArray[]{result}));
|
||||
|
||||
return result;
|
||||
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
|
||||
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
|
||||
Nd4j.getExecutioner().exec(new BroadcastSubOp(this,other,result,broadcastDimensions));
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
LinAlgExceptions.assertSameShape(other, result);
|
||||
|
||||
|
||||
Nd4j.getExecutioner().exec(new OldSubOp(this, other,result));
|
||||
Shape.assertBroadcastable("subi", this, other, result);
|
||||
Nd4j.exec(new SubOp(this, other, result));
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -3851,33 +3783,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
@Override
|
||||
public INDArray addi(INDArray other, INDArray result) {
|
||||
validateNumericalArray("addi", false);
|
||||
if (other.isScalar()) {
|
||||
return this.addi(other.getDouble(0), result);
|
||||
}
|
||||
|
||||
if (isScalar()) {
|
||||
return other.addi(getDouble(0), result);
|
||||
}
|
||||
|
||||
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
|
||||
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
|
||||
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
|
||||
|
||||
Nd4j.exec(new AddOp(new INDArray[]{this, other}, new INDArray[]{result}));
|
||||
|
||||
Shape.assertBroadcastable("addi", this, other, result);
|
||||
Nd4j.exec(new AddOp(this, other, result));
|
||||
return result;
|
||||
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
|
||||
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
|
||||
result = Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(),other.shape()));
|
||||
Nd4j.getExecutioner().exec(new BroadcastAddOp(this,other,result,broadcastDimensions));
|
||||
return result;
|
||||
} else {
|
||||
|
||||
LinAlgExceptions.assertSameShape(this, other, result);
|
||||
|
||||
Nd4j.getExecutioner().exec(new OldAddOp(this, other, result));
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3954,7 +3862,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
@Override
|
||||
public INDArray rdivi(INDArray other, INDArray result) {
|
||||
validateNumericalArray("rdivi", false);
|
||||
return other.divi(this, result);
|
||||
Shape.assertBroadcastable("rdivi", this, other, result);
|
||||
Nd4j.exec(new RDivOp(this, other, result));
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -4003,33 +3913,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
@Override
|
||||
public INDArray rsubi(INDArray other, INDArray result) {
|
||||
validateNumericalArray("rsubi", false);
|
||||
if (other.isScalar()) {
|
||||
return this.rsubi(other.getDouble(0), result);
|
||||
}
|
||||
|
||||
if (isScalar()) {
|
||||
return other.rsubi(getDouble(0), result);
|
||||
}
|
||||
|
||||
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
|
||||
val outShape = Shape.broadcastOutputShape(this.shape(), other.shape());
|
||||
Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));
|
||||
|
||||
Nd4j.exec(new RSubOp(new INDArray[]{this, other}, new INDArray[]{result}));
|
||||
|
||||
Shape.assertBroadcastable("rsubi", this, other, result);
|
||||
Nd4j.exec(new RSubOp(this, other, result));
|
||||
return result;
|
||||
} else if(!Shape.shapeEquals(this.shape(),other.shape())) {
|
||||
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
|
||||
result = Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(),other.shape()));
|
||||
Nd4j.getExecutioner().exec(new BroadcastRSubOp(this,other,result,broadcastDimensions));
|
||||
return result;
|
||||
} else {
|
||||
|
||||
LinAlgExceptions.assertSameShape(this, other, result);
|
||||
|
||||
Nd4j.getExecutioner().exec(new OldRSubOp(this, other, result));
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -6796,6 +6682,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
throw new IllegalStateException("Cannot perform operation " + opName + " on empty array with datatype " + dataType());
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public boolean closeable() {
|
||||
if (released || isAttached())
|
||||
|
|
|
@ -38,6 +38,10 @@ public class AddOp extends BaseDynamicTransformOp {
|
|||
super(sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public AddOp(INDArray first, INDArray second, INDArray result){
|
||||
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
|
||||
}
|
||||
|
||||
public AddOp(INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
|
|
@ -37,6 +37,10 @@ public class DivOp extends BaseDynamicTransformOp {
|
|||
super(sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public DivOp(INDArray first, INDArray second, INDArray result){
|
||||
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
|
||||
}
|
||||
|
||||
public DivOp( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
|
|
@ -37,6 +37,10 @@ public class MulOp extends BaseDynamicTransformOp {
|
|||
super(sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public MulOp(INDArray first, INDArray second, INDArray result){
|
||||
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
|
||||
}
|
||||
|
||||
public MulOp( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
|
|
@ -28,10 +28,9 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Add operation for two operands
|
||||
*
|
||||
* @author Adam Gibson
|
||||
* @deprecated Use {@link AddOp}
|
||||
*/
|
||||
@Deprecated
|
||||
public class OldAddOp extends BaseTransformAnyOp {
|
||||
public OldAddOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
|
||||
super(sameDiff, i_v1, i_v2);
|
||||
|
|
|
@ -28,10 +28,9 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Division operation
|
||||
*
|
||||
* @author Adam Gibson
|
||||
* @deprecated Use {@link DivOp}
|
||||
*/
|
||||
@Deprecated
|
||||
public class OldDivOp extends BaseTransformAnyOp {
|
||||
public OldDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
|
||||
super(sameDiff, i_v1, i_v2);
|
||||
|
|
|
@ -28,10 +28,9 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Multiplication operation
|
||||
*
|
||||
* @author Adam Gibson
|
||||
* @deprecated Use {@link MulOp}
|
||||
*/
|
||||
@Deprecated
|
||||
public class OldMulOp extends BaseTransformAnyOp {
|
||||
public OldMulOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
|
||||
super(sameDiff, i_v1, i_v2);
|
||||
|
|
|
@ -28,10 +28,9 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* OldReverse Division operation
|
||||
*
|
||||
* @author Adam Gibson
|
||||
* @deprecated Use {@link RDivOp}
|
||||
*/
|
||||
@Deprecated
|
||||
public class OldRDivOp extends BaseTransformAnyOp {
|
||||
public OldRDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
|
||||
super(sameDiff, i_v1, i_v2);
|
||||
|
|
|
@ -26,10 +26,9 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Division operation
|
||||
*
|
||||
* @author Adam Gibson
|
||||
* @deprecated Use {@link RSubOp}
|
||||
*/
|
||||
@Deprecated
|
||||
public class OldRSubOp extends BaseTransformAnyOp {
|
||||
public OldRSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
|
||||
super(sameDiff, i_v1, i_v2);
|
||||
|
|
|
@ -28,10 +28,9 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Division operation
|
||||
*
|
||||
* @author Adam Gibson
|
||||
* @deprecated Use {@link SubOp}
|
||||
*/
|
||||
@Deprecated
|
||||
public class OldSubOp extends BaseTransformAnyOp {
|
||||
public OldSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
|
||||
super(sameDiff, i_v1, i_v2);
|
||||
|
|
|
@ -38,6 +38,10 @@ public class RDivOp extends BaseDynamicTransformOp {
|
|||
super(sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public RDivOp(INDArray first, INDArray second, INDArray result){
|
||||
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
|
||||
}
|
||||
|
||||
public RDivOp( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
|
|
@ -45,6 +45,10 @@ public class RSubOp extends BaseDynamicTransformOp {
|
|||
this(sameDiff, new SDVariable[]{i_v1, i_v2}, inPlace);
|
||||
}
|
||||
|
||||
public RSubOp(INDArray first, INDArray second, INDArray result){
|
||||
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
|
||||
}
|
||||
|
||||
public RSubOp() {}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -37,6 +37,10 @@ public class SubOp extends BaseDynamicTransformOp {
|
|||
super(sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public SubOp(INDArray first, INDArray second, INDArray result){
|
||||
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
|
||||
}
|
||||
|
||||
public SubOp( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
|
|
@ -251,6 +251,41 @@ public class Shape {
|
|||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Assert that the broadcast operation {@code result = first.op(second)} is valid, given the
|
||||
* shapes of first, second, and result.<br>
|
||||
* Throws an exception otherwise
|
||||
*
|
||||
* @param op Name of the operation
|
||||
* @param first First array
|
||||
* @param second Second array
|
||||
* @param result Result arrray.
|
||||
*/
|
||||
public static void assertBroadcastable(String op, INDArray first, INDArray second, INDArray result){
|
||||
long[] fShape = first.shape();
|
||||
long[] sShape = second.shape();
|
||||
Preconditions.checkState(Shape.areShapesBroadcastable(fShape, sShape),
|
||||
"Cannot perform operation \"%s\" - shapes are not equal and are not broadcastable." +
|
||||
"first.shape=%s, second.shape=%s", op, fShape, sShape);
|
||||
|
||||
long[] outShape = Shape.broadcastOutputShape(fShape, sShape);
|
||||
if (!Arrays.equals(outShape, result.shape())) {
|
||||
//Two cases
|
||||
// 1. x.addi(y)
|
||||
// 2. x.addi(y, z)
|
||||
|
||||
String extra = "";
|
||||
if(first == result){
|
||||
extra = ".\nIn-place operations like x." + op + "(y) can only be performed when x and y have the same shape," +
|
||||
" or x and y are broadcastable with x.shape() == broadcastShape(x,y)";
|
||||
}
|
||||
|
||||
throw new IllegalStateException("Cannot perform in-place operation \"" + op + "\": result array shape does" +
|
||||
" not match the broadcast operation output shape: " + Arrays.toString(fShape) + "." + op + "(" +
|
||||
Arrays.toString(sShape) + ") != " + Arrays.toString(result.shape()) + extra);
|
||||
}
|
||||
}
|
||||
|
||||
public static long[] broadcastOutputShape(long[] left,long[] right) {
|
||||
if (containsZeros(left))
|
||||
return left;
|
||||
|
|
|
@ -80,6 +80,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
|
|||
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
|
@ -6222,7 +6223,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
assertEquals(expected, b.rdiv(2));
|
||||
assertEquals(expected2, d.rdivColumnVector(c));
|
||||
|
||||
assertEquals(expected, b.rdiv(Nd4j.scalar(2)));
|
||||
assertEquals(expected, b.rdiv(Nd4j.scalar(2.0)));
|
||||
assertEquals(expected, b.rdivColumnVector(Nd4j.scalar(2)));
|
||||
}
|
||||
|
||||
|
@ -7958,7 +7959,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
c.addOutputArgument(out);
|
||||
Nd4j.getExecutioner().exec(c);
|
||||
|
||||
assertEquals(Nd4j.createFromArray(1f, 3f, 4f), out);
|
||||
List<LongShapeDescriptor> l = c.calculateOutputShape();
|
||||
|
||||
System.out.println(Arrays.toString(l.get(0).getShape()));
|
||||
|
||||
//from [4,4,3] to [2,4,6] then crop to [2,4,5]
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* @author raver119@gmail.com
|
||||
|
@ -122,42 +123,42 @@ public class BasicBroadcastTests extends BaseNd4jTest {
|
|||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void basicBroadcastFailureTest_1() {
|
||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||
val z = x.subi(y);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void basicBroadcastFailureTest_2() {
|
||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||
val z = x.divi(y);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void basicBroadcastFailureTest_3() {
|
||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||
val z = x.muli(y);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void basicBroadcastFailureTest_4() {
|
||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||
val z = x.addi(y);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void basicBroadcastFailureTest_5() {
|
||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||
val z = x.rsubi(y);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void basicBroadcastFailureTest_6() {
|
||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||
|
@ -206,7 +207,7 @@ public class BasicBroadcastTests extends BaseNd4jTest {
|
|||
assertEquals(y, z);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void emptyBroadcastTest_2() {
|
||||
val x = Nd4j.create(DataType.FLOAT, 1, 2);
|
||||
val y = Nd4j.create(DataType.FLOAT, 0, 2);
|
||||
|
@ -226,6 +227,67 @@ public class BasicBroadcastTests extends BaseNd4jTest {
|
|||
assertEquals(y, z);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testValidInvalidBroadcast(){
|
||||
INDArray x = Nd4j.rand(3,1);
|
||||
INDArray y = Nd4j.create(3, 4);
|
||||
|
||||
x.add(y);
|
||||
y.addi(x);
|
||||
try {
|
||||
x.addi(y);
|
||||
} catch (Exception e){
|
||||
String s = e.getMessage();
|
||||
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
|
||||
}
|
||||
|
||||
x.sub(y);
|
||||
y.subi(x);
|
||||
try {
|
||||
x.subi(y);
|
||||
} catch (Exception e){
|
||||
String s = e.getMessage();
|
||||
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
|
||||
}
|
||||
|
||||
x.mul(y);
|
||||
y.muli(x);
|
||||
try {
|
||||
x.muli(y);
|
||||
} catch (Exception e){
|
||||
String s = e.getMessage();
|
||||
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
|
||||
}
|
||||
|
||||
x.div(y);
|
||||
y.divi(x);
|
||||
try {
|
||||
x.divi(y);
|
||||
} catch (Exception e){
|
||||
String s = e.getMessage();
|
||||
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
|
||||
}
|
||||
|
||||
x.rsub(y);
|
||||
y.rsubi(x);
|
||||
try {
|
||||
x.rsubi(y);
|
||||
} catch (Exception e){
|
||||
String s = e.getMessage();
|
||||
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
|
||||
}
|
||||
|
||||
x.rdiv(y);
|
||||
y.rdivi(x);
|
||||
try {
|
||||
x.rdivi(y);
|
||||
} catch (Exception e){
|
||||
String s = e.getMessage();
|
||||
assertTrue(s, s.contains("broadcast") && s.contains("shape"));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
|
@ -7,8 +7,7 @@ import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
|||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class CropAndResizeDataSetPreProcessorTest {
|
||||
|
||||
|
@ -93,10 +92,7 @@ public class CropAndResizeDataSetPreProcessorTest {
|
|||
// Assert
|
||||
INDArray results = ds.getFeatures();
|
||||
long[] shape = results.shape();
|
||||
assertEquals(1, shape[0]);
|
||||
assertEquals(4, shape[1]);
|
||||
assertEquals(3, shape[2]);
|
||||
assertEquals(3, shape[3]);
|
||||
assertArrayEquals(new long[]{1, 4, 3, 3}, shape);
|
||||
|
||||
// Test a few values
|
||||
assertEquals(55.0, results.getDouble(0, 0, 0, 0), 0.0);
|
||||
|
|
Loading…
Reference in New Issue