fix for eclipse#8087 (#129)
* fix for #8087 Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove commented code. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * removing trueScalar. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove tueScalar. Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
10d676e0b8
commit
38310777ee
|
@ -1280,62 +1280,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
||||||
return create(new double[] {value}, new int[0], new int[0], offset);
|
return create(new double[] {value}, new int[0], new int[0], offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray trueScalar(DataType dataType, Number value) {
|
|
||||||
val ws = Nd4j.getMemoryManager().getCurrentWorkspace();
|
|
||||||
|
|
||||||
switch (dataType) {
|
|
||||||
case DOUBLE:
|
|
||||||
return create(new double[] {value.doubleValue()}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
case FLOAT:
|
|
||||||
return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
case BFLOAT16:
|
|
||||||
return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
case HALF:
|
|
||||||
return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
case UINT32:
|
|
||||||
case INT:
|
|
||||||
return create(new int[] {value.intValue()}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
case UINT64:
|
|
||||||
case LONG:
|
|
||||||
return create(new long[] {value.longValue()}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
case UINT16:
|
|
||||||
case SHORT:
|
|
||||||
return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
case BYTE:
|
|
||||||
return create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
case UBYTE:
|
|
||||||
return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
case BOOL:
|
|
||||||
val b = value.byteValue();
|
|
||||||
val arr = create(new byte[] {b}, new long[] {}, new long[] {}, dataType, ws);
|
|
||||||
return arr;
|
|
||||||
|
|
||||||
default:
|
|
||||||
throw new UnsupportedOperationException("Unsupported data type used: " + dataType);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray trueScalar(Number value) {
|
|
||||||
val ws = Nd4j.getMemoryManager().getCurrentWorkspace();
|
|
||||||
|
|
||||||
if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends com.google.common.util.concurrent.AtomicDouble */
|
|
||||||
return create(new double[] {value.doubleValue()}, new long[] {}, new long[] {}, DataType.DOUBLE, ws);
|
|
||||||
else if (value instanceof Float)
|
|
||||||
return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, DataType.FLOAT, ws);
|
|
||||||
else if (value instanceof Long || value instanceof AtomicLong)
|
|
||||||
return create(new long[] {value.longValue()}, new long[] {}, new long[] {}, DataType.LONG, ws);
|
|
||||||
else if (value instanceof Integer || value instanceof AtomicInteger)
|
|
||||||
return create(new int[] {value.intValue()}, new long[] {}, new long[] {}, DataType.INT, ws);
|
|
||||||
else if (value instanceof Short)
|
|
||||||
return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, DataType.SHORT, ws);
|
|
||||||
else if (value instanceof Byte)
|
|
||||||
return create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, DataType.BYTE, ws);
|
|
||||||
else
|
|
||||||
throw new UnsupportedOperationException("Unsupported data type: [" + value.getClass().getSimpleName() + "]");
|
|
||||||
}
|
|
||||||
|
|
||||||
public INDArray trueVector(boolean[] data) {
|
public INDArray trueVector(boolean[] data) {
|
||||||
return create(data, new long[] {data.length}, new long[]{1}, DataType.BOOL, Nd4j.getMemoryManager().getCurrentWorkspace());
|
return create(data, new long[] {data.length}, new long[]{1}, DataType.BOOL, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||||
}
|
}
|
||||||
|
@ -1364,8 +1308,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
||||||
return create(data, new long[] {data.length}, new long[]{1}, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
return create(data, new long[] {data.length}, new long[]{1}, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a scalar nd array with the specified value and offset
|
* Create a scalar nd array with the specified value and offset
|
||||||
*
|
*
|
||||||
|
|
|
@ -990,12 +990,6 @@ public interface NDArrayFactory {
|
||||||
|
|
||||||
INDArray empty(DataType type);
|
INDArray empty(DataType type);
|
||||||
|
|
||||||
@Deprecated
|
|
||||||
INDArray trueScalar(Number value);
|
|
||||||
|
|
||||||
@Deprecated
|
|
||||||
INDArray trueScalar(DataType dataType, Number value);
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
INDArray trueVector(boolean[] data);
|
INDArray trueVector(boolean[] data);
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
|
|
@ -3736,19 +3736,6 @@ public class Nd4j {
|
||||||
return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* This method creates new 0D INDArray, aka scalar.
|
|
||||||
*
|
|
||||||
* PLEASE NOTE: Temporary method, added to ensure backward compatibility
|
|
||||||
* @param scalar data for INDArray.
|
|
||||||
* @return new INDArray
|
|
||||||
* * @deprecated Use Nd4j.scalar methods, such as {@link #scalar(double)} or {@link #scalar(DataType, Number)}
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public static INDArray trueScalar(Number scalar) {
|
|
||||||
return INSTANCE.trueScalar(scalar);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @deprecated Use {@link #createFromArray(boolean...)}
|
* @deprecated Use {@link #createFromArray(boolean...)}
|
||||||
*/
|
*/
|
||||||
|
@ -5087,9 +5074,35 @@ public class Nd4j {
|
||||||
* @param value the value to initialize the scalar with
|
* @param value the value to initialize the scalar with
|
||||||
* @return the created ndarray
|
* @return the created ndarray
|
||||||
*/
|
*/
|
||||||
@SuppressWarnings("deprecation")
|
|
||||||
public static INDArray scalar(DataType dataType, Number value) {
|
public static INDArray scalar(DataType dataType, Number value) {
|
||||||
return INSTANCE.trueScalar(dataType, value);
|
val ws = Nd4j.getMemoryManager().getCurrentWorkspace();
|
||||||
|
|
||||||
|
switch (dataType) {
|
||||||
|
case DOUBLE:
|
||||||
|
return INSTANCE.create(new double[] {value.doubleValue()}, new long[] {}, new long[] {}, dataType, ws);
|
||||||
|
case FLOAT:
|
||||||
|
case BFLOAT16:
|
||||||
|
case HALF:
|
||||||
|
return INSTANCE.create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
|
||||||
|
case UINT32:
|
||||||
|
case INT:
|
||||||
|
return INSTANCE.create(new int[] {value.intValue()}, new long[] {}, new long[] {}, dataType, ws);
|
||||||
|
case UINT64:
|
||||||
|
case LONG:
|
||||||
|
return INSTANCE.create(new long[] {value.longValue()}, new long[] {}, new long[] {}, dataType, ws);
|
||||||
|
case UINT16:
|
||||||
|
case SHORT:
|
||||||
|
return INSTANCE.create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
|
||||||
|
case BYTE:
|
||||||
|
return INSTANCE.create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, dataType, ws);
|
||||||
|
case UBYTE:
|
||||||
|
return INSTANCE.create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
|
||||||
|
case BOOL:
|
||||||
|
return INSTANCE.create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, dataType, ws);
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw new UnsupportedOperationException("Unsupported data type used: " + dataType);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -251,11 +251,6 @@ public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray trueScalar(Number value) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
||||||
return null;
|
return null;
|
||||||
|
|
|
@ -504,7 +504,7 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
double exp = Nd4j.diag(in).sumNumber().doubleValue();
|
double exp = Nd4j.diag(in).sumNumber().doubleValue();
|
||||||
|
|
||||||
TestCase tc = new TestCase(sd)
|
TestCase tc = new TestCase(sd)
|
||||||
.expected(trace, Nd4j.trueScalar(exp))
|
.expected(trace, Nd4j.scalar(exp))
|
||||||
.testName(Arrays.toString(inShape));
|
.testName(Arrays.toString(inShape));
|
||||||
|
|
||||||
String err = OpValidation.validate(tc);
|
String err = OpValidation.validate(tc);
|
||||||
|
@ -1296,7 +1296,7 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
if(shape.length > 0){
|
if(shape.length > 0){
|
||||||
arr = Nd4j.rand(shape);
|
arr = Nd4j.rand(shape);
|
||||||
} else {
|
} else {
|
||||||
arr = Nd4j.trueScalar(Nd4j.rand(new int[]{1,1}).getDouble(0));
|
arr = Nd4j.scalar(Nd4j.rand(new int[]{1,1}).getDouble(0));
|
||||||
}
|
}
|
||||||
SDVariable var = sd.var("in", arr);
|
SDVariable var = sd.var("in", arr);
|
||||||
SDVariable xLike;
|
SDVariable xLike;
|
||||||
|
@ -1388,7 +1388,7 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
INDArray inArr;
|
INDArray inArr;
|
||||||
if (shape == null) {
|
if (shape == null) {
|
||||||
inArr = Nd4j.trueScalar(1.0);
|
inArr = Nd4j.scalar(1.0);
|
||||||
} else {
|
} else {
|
||||||
inArr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(shape);
|
inArr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(shape);
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
||||||
} else {
|
} else {
|
||||||
dLdOut = Nd4j.trueScalar(0.5);
|
dLdOut = Nd4j.scalar(0.5);
|
||||||
}
|
}
|
||||||
INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), 0.5);
|
INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), 0.5);
|
||||||
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
|
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
|
||||||
|
@ -164,7 +164,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
||||||
} else {
|
} else {
|
||||||
dLdOut = Nd4j.trueScalar(0.5);
|
dLdOut = Nd4j.scalar(0.5);
|
||||||
}
|
}
|
||||||
INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), 0.5 / preReduceInput.length());
|
INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), 0.5 / preReduceInput.length());
|
||||||
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
|
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
|
||||||
|
@ -178,7 +178,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMeanBP_Rank1() {
|
public void testMeanBP_Rank1() {
|
||||||
INDArray dLdOut = Nd4j.trueScalar(0.5);
|
INDArray dLdOut = Nd4j.scalar(0.5);
|
||||||
INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3});
|
INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3});
|
||||||
INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5/3);
|
INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5/3);
|
||||||
|
|
||||||
|
@ -261,7 +261,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
||||||
} else {
|
} else {
|
||||||
dLdOut = Nd4j.trueScalar(0.5);
|
dLdOut = Nd4j.scalar(0.5);
|
||||||
}
|
}
|
||||||
INDArray dLdInExpected = Nd4j.zeros(preReduceInput.shape());
|
INDArray dLdInExpected = Nd4j.zeros(preReduceInput.shape());
|
||||||
dLdInExpected.putScalar(new int[]{2, 2}, 0.5); //Minimum value: position at [2,2]
|
dLdInExpected.putScalar(new int[]{2, 2}, 0.5); //Minimum value: position at [2,2]
|
||||||
|
@ -343,7 +343,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
||||||
} else {
|
} else {
|
||||||
dLdOut = Nd4j.trueScalar(0.5);
|
dLdOut = Nd4j.scalar(0.5);
|
||||||
}
|
}
|
||||||
INDArray dLdInExpected = Nd4j.zeros(preReduceInput.shape());
|
INDArray dLdInExpected = Nd4j.zeros(preReduceInput.shape());
|
||||||
dLdInExpected.putScalar(new int[]{2, 2}, 0.5); //Maximum value: position at [2,2]
|
dLdInExpected.putScalar(new int[]{2, 2}, 0.5); //Maximum value: position at [2,2]
|
||||||
|
@ -415,7 +415,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
||||||
} else {
|
} else {
|
||||||
dLdOut = Nd4j.trueScalar(0.5);
|
dLdOut = Nd4j.scalar(0.5);
|
||||||
}
|
}
|
||||||
double prod = preReduceInput.prodNumber().doubleValue();
|
double prod = preReduceInput.prodNumber().doubleValue();
|
||||||
INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), prod).divi(preReduceInput).muli(0.5);
|
INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), prod).divi(preReduceInput).muli(0.5);
|
||||||
|
@ -500,7 +500,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
||||||
} else {
|
} else {
|
||||||
dLdOut = Nd4j.trueScalar(0.5);
|
dLdOut = Nd4j.scalar(0.5);
|
||||||
}
|
}
|
||||||
|
|
||||||
double stdev = preReduceInput.stdNumber(biasCorrected).doubleValue();
|
double stdev = preReduceInput.stdNumber(biasCorrected).doubleValue();
|
||||||
|
@ -523,7 +523,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStdevBP_Rank1() {
|
public void testStdevBP_Rank1() {
|
||||||
INDArray dLdOut = Nd4j.trueScalar(0.5);
|
INDArray dLdOut = Nd4j.scalar(0.5);
|
||||||
INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3});
|
INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3});
|
||||||
double stdev = preReduceInput.stdNumber(true).doubleValue();
|
double stdev = preReduceInput.stdNumber(true).doubleValue();
|
||||||
double mean = preReduceInput.meanNumber().doubleValue();
|
double mean = preReduceInput.meanNumber().doubleValue();
|
||||||
|
@ -602,7 +602,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
||||||
} else {
|
} else {
|
||||||
dLdOut = Nd4j.trueScalar(0.5);
|
dLdOut = Nd4j.scalar(0.5);
|
||||||
}
|
}
|
||||||
|
|
||||||
double var = preReduceInput.var(biasCorrected).getDouble(0);
|
double var = preReduceInput.var(biasCorrected).getDouble(0);
|
||||||
|
@ -811,7 +811,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
||||||
} else {
|
} else {
|
||||||
dLdOut = Nd4j.trueScalar(0.5);
|
dLdOut = Nd4j.scalar(0.5);
|
||||||
}
|
}
|
||||||
INDArray dLdInExpected = sgn.muli(0.5);
|
INDArray dLdInExpected = sgn.muli(0.5);
|
||||||
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
|
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
|
||||||
|
@ -873,7 +873,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
|
||||||
} else {
|
} else {
|
||||||
dLdOut = Nd4j.trueScalar(0.5);
|
dLdOut = Nd4j.scalar(0.5);
|
||||||
}
|
}
|
||||||
INDArray dLdInExpected = sgn.mul(max).mul(0.5);
|
INDArray dLdInExpected = sgn.mul(max).mul(0.5);
|
||||||
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
|
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
|
||||||
|
|
|
@ -234,12 +234,12 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
case 10:
|
case 10:
|
||||||
loss = sd.math().countNonZero("loss", input);
|
loss = sd.math().countNonZero("loss", input);
|
||||||
name = "countNonZero";
|
name = "countNonZero";
|
||||||
tc.expectedOutput("loss", Nd4j.trueScalar(inputArr.length()));
|
tc.expectedOutput("loss", Nd4j.scalar(inputArr.length()));
|
||||||
break;
|
break;
|
||||||
case 11:
|
case 11:
|
||||||
loss = sd.math().countZero("loss", input);
|
loss = sd.math().countZero("loss", input);
|
||||||
name = "countZero";
|
name = "countZero";
|
||||||
tc.expectedOutput("loss", Nd4j.trueScalar(0));
|
tc.expectedOutput("loss", Nd4j.scalar(0));
|
||||||
break;
|
break;
|
||||||
case 12:
|
case 12:
|
||||||
loss = sd.math().amax("loss", input);
|
loss = sd.math().amax("loss", input);
|
||||||
|
@ -280,21 +280,21 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
name = "sqnorm";
|
name = "sqnorm";
|
||||||
loss = sd.squaredNorm("loss", input);
|
loss = sd.squaredNorm("loss", input);
|
||||||
double norm2 = inputArr.norm2Number().doubleValue();
|
double norm2 = inputArr.norm2Number().doubleValue();
|
||||||
tc.expected("loss", Nd4j.trueScalar(norm2 * norm2));
|
tc.expected("loss", Nd4j.scalar(norm2 * norm2));
|
||||||
break;
|
break;
|
||||||
case 19:
|
case 19:
|
||||||
inputArr = Nd4j.rand(minibatch, nOut);
|
inputArr = Nd4j.rand(minibatch, nOut);
|
||||||
name = "logEntropy";
|
name = "logEntropy";
|
||||||
loss = sd.math().logEntropy("loss", input);
|
loss = sd.math().logEntropy("loss", input);
|
||||||
double logEntropy = inputArr.logEntropyNumber().doubleValue();
|
double logEntropy = inputArr.logEntropyNumber().doubleValue();
|
||||||
tc.expected(loss, Nd4j.trueScalar(logEntropy));
|
tc.expected(loss, Nd4j.scalar(logEntropy));
|
||||||
break;
|
break;
|
||||||
case 20:
|
case 20:
|
||||||
inputArr = Nd4j.rand(minibatch, nOut);
|
inputArr = Nd4j.rand(minibatch, nOut);
|
||||||
name = "shannonEntropy";
|
name = "shannonEntropy";
|
||||||
loss = sd.math().shannonEntropy("loss", input);
|
loss = sd.math().shannonEntropy("loss", input);
|
||||||
double shannonEntropy = inputArr.shannonEntropyNumber().doubleValue();
|
double shannonEntropy = inputArr.shannonEntropyNumber().doubleValue();
|
||||||
tc.expected(loss, Nd4j.trueScalar(shannonEntropy));
|
tc.expected(loss, Nd4j.scalar(shannonEntropy));
|
||||||
if (OpValidationSuite.IGNORE_FAILING) {
|
if (OpValidationSuite.IGNORE_FAILING) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1030,7 +1030,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
//Case 1: shape is provided + scalar
|
//Case 1: shape is provided + scalar
|
||||||
|
|
||||||
sd = SameDiff.create();
|
sd = SameDiff.create();
|
||||||
ia = Nd4j.trueScalar(3.0);
|
ia = Nd4j.scalar(3.0);
|
||||||
in = sd.var(ia);
|
in = sd.var(ia);
|
||||||
constant = sd.constant(in, 3,4,5);
|
constant = sd.constant(in, 3,4,5);
|
||||||
INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
|
INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
|
||||||
|
@ -1169,7 +1169,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
double d = new LUDecomposition(CheckUtil.convertToApacheMatrix(in)).getDeterminant();
|
double d = new LUDecomposition(CheckUtil.convertToApacheMatrix(in)).getDeterminant();
|
||||||
|
|
||||||
|
|
||||||
INDArray outExp = Nd4j.trueScalar(d);
|
INDArray outExp = Nd4j.scalar(d);
|
||||||
|
|
||||||
String err = OpValidation.validate(new TestCase(sd)
|
String err = OpValidation.validate(new TestCase(sd)
|
||||||
.expected(md.getVarName(), outExp));
|
.expected(md.getVarName(), outExp));
|
||||||
|
@ -1193,7 +1193,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(d, d2, 1e-5);
|
assertEquals(d, d2, 1e-5);
|
||||||
|
|
||||||
|
|
||||||
INDArray outExp = Nd4j.trueScalar(d);
|
INDArray outExp = Nd4j.scalar(d);
|
||||||
|
|
||||||
String err = OpValidation.validate(new TestCase(sd)
|
String err = OpValidation.validate(new TestCase(sd)
|
||||||
.expected(md.getVarName(), outExp));
|
.expected(md.getVarName(), outExp));
|
||||||
|
@ -1224,7 +1224,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
- a[0][0] * a[1][2] * a[2][1];
|
- a[0][0] * a[1][2] * a[2][1];
|
||||||
assertEquals(d, d2, 1e-6); //Manual calc and Apache commons both match: 0.03589524995561552
|
assertEquals(d, d2, 1e-6); //Manual calc and Apache commons both match: 0.03589524995561552
|
||||||
|
|
||||||
INDArray outExp = Nd4j.trueScalar(d);
|
INDArray outExp = Nd4j.scalar(d);
|
||||||
|
|
||||||
String err = OpValidation.validate(new TestCase(sd)
|
String err = OpValidation.validate(new TestCase(sd)
|
||||||
.expected(md.getVarName(), outExp));
|
.expected(md.getVarName(), outExp));
|
||||||
|
@ -1247,7 +1247,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
//System.out.println(d);
|
//System.out.println(d);
|
||||||
|
|
||||||
String err = OpValidation.validate(new TestCase(sd)
|
String err = OpValidation.validate(new TestCase(sd)
|
||||||
.expected(md.getVarName(), Nd4j.trueScalar(d)));
|
.expected(md.getVarName(), Nd4j.scalar(d)));
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1792,7 +1792,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
@Test
|
@Test
|
||||||
public void testSplit1(){
|
public void testSplit1(){
|
||||||
INDArray in = Nd4j.linspace(1,10,10).reshape(10);
|
INDArray in = Nd4j.linspace(1,10,10).reshape(10);
|
||||||
INDArray axis = Nd4j.trueScalar(-1);
|
INDArray axis = Nd4j.scalar(-1);
|
||||||
|
|
||||||
INDArray out1 = Nd4j.create(new long[]{5});
|
INDArray out1 = Nd4j.create(new long[]{5});
|
||||||
INDArray out2 = Nd4j.create(new long[]{5});
|
INDArray out2 = Nd4j.create(new long[]{5});
|
||||||
|
|
|
@ -1408,7 +1408,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
public void testScatterOpsScalar(){
|
public void testScatterOpsScalar(){
|
||||||
for(String s : new String[]{"add", "sub", "mul", "div"}) {
|
for(String s : new String[]{"add", "sub", "mul", "div"}) {
|
||||||
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
|
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
|
||||||
INDArray indices = Nd4j.trueScalar(5);
|
INDArray indices = Nd4j.scalar(5);
|
||||||
INDArray upd = Nd4j.create(new double[]{10, 20, 30});
|
INDArray upd = Nd4j.create(new double[]{10, 20, 30});
|
||||||
|
|
||||||
//The non-scalar case works:
|
//The non-scalar case works:
|
||||||
|
@ -1452,7 +1452,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
public void testPad(){
|
public void testPad(){
|
||||||
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
|
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
|
||||||
INDArray pad = Nd4j.create(new double[]{1,1}, new long[]{1,2}).castTo(DataType.LONG);
|
INDArray pad = Nd4j.create(new double[]{1,1}, new long[]{1,2}).castTo(DataType.LONG);
|
||||||
INDArray value = Nd4j.trueScalar(10.0);
|
INDArray value = Nd4j.scalar(10.0);
|
||||||
|
|
||||||
INDArray out = Nd4j.create(new long[]{7});
|
INDArray out = Nd4j.create(new long[]{7});
|
||||||
|
|
||||||
|
|
|
@ -115,7 +115,7 @@ public class ByteOrderTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testScalarEncoding() {
|
public void testScalarEncoding() {
|
||||||
val scalar = Nd4j.trueScalar(2.0f);
|
val scalar = Nd4j.scalar(2.0f);
|
||||||
|
|
||||||
FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0);
|
FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0);
|
||||||
val fb = scalar.toFlatArray(bufferBuilder);
|
val fb = scalar.toFlatArray(bufferBuilder);
|
||||||
|
|
|
@ -797,7 +797,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
|
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
|
||||||
assertNotNull(tg);
|
assertNotNull(tg);
|
||||||
val input = Nd4j.trueScalar(4.0);
|
val input = Nd4j.scalar(4.0);
|
||||||
tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
|
tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
|
||||||
|
|
||||||
tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));
|
tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));
|
||||||
|
@ -815,7 +815,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
|
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
|
||||||
assertNotNull(tg);
|
assertNotNull(tg);
|
||||||
val input = Nd4j.trueScalar(9.0);
|
val input = Nd4j.scalar(9.0);
|
||||||
tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
|
tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
|
||||||
|
|
||||||
//tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0.fb"));
|
//tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0.fb"));
|
||||||
|
@ -835,7 +835,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
|
||||||
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
|
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
|
||||||
assertNotNull(tg);
|
assertNotNull(tg);
|
||||||
val input0 = Nd4j.create(2, 2).assign(-4.0);
|
val input0 = Nd4j.create(2, 2).assign(-4.0);
|
||||||
val input1 = Nd4j.trueScalar(1.0);
|
val input1 = Nd4j.scalar(1.0);
|
||||||
tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
|
tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
|
||||||
tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
|
tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
|
||||||
|
|
||||||
|
@ -855,7 +855,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
|
||||||
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
|
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
|
||||||
assertNotNull(tg);
|
assertNotNull(tg);
|
||||||
val input0 = Nd4j.create(2, 2).assign(-9.0);
|
val input0 = Nd4j.create(2, 2).assign(-9.0);
|
||||||
val input1 = Nd4j.trueScalar(1.0);
|
val input1 = Nd4j.scalar(1.0);
|
||||||
tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
|
tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
|
||||||
tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
|
tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
|
||||||
|
|
||||||
|
@ -964,7 +964,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
|
||||||
assertNotNull(tg);
|
assertNotNull(tg);
|
||||||
|
|
||||||
val input0 = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {2, 2});
|
val input0 = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {2, 2});
|
||||||
val input1 = Nd4j.trueScalar(11f);
|
val input1 = Nd4j.scalar(11f);
|
||||||
|
|
||||||
tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
|
tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
|
||||||
tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
|
tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
|
||||||
|
|
|
@ -5880,9 +5880,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testScalar_2() {
|
public void testScalar_2() {
|
||||||
val scalar = Nd4j.trueScalar(2.0f);
|
val scalar = Nd4j.scalar(2.0f);
|
||||||
val scalar2 = Nd4j.trueScalar(2.0f);
|
val scalar2 = Nd4j.scalar(2.0f);
|
||||||
val scalar3 = Nd4j.trueScalar(3.0f);
|
val scalar3 = Nd4j.scalar(3.0f);
|
||||||
|
|
||||||
assertTrue(scalar.isScalar());
|
assertTrue(scalar.isScalar());
|
||||||
assertEquals(1, scalar.length());
|
assertEquals(1, scalar.length());
|
||||||
|
@ -5917,7 +5917,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
@Test
|
@Test
|
||||||
public void testVectorScalar_2() {
|
public void testVectorScalar_2() {
|
||||||
val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5});
|
val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5});
|
||||||
val scalar = Nd4j.trueScalar(2.0f);
|
val scalar = Nd4j.scalar(2.0f);
|
||||||
val exp = Nd4j.trueVector(new float[]{3, 4, 5, 6, 7});
|
val exp = Nd4j.trueVector(new float[]{3, 4, 5, 6, 7});
|
||||||
|
|
||||||
vector.addi(scalar);
|
vector.addi(scalar);
|
||||||
|
@ -5927,7 +5927,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReshapeScalar() {
|
public void testReshapeScalar() {
|
||||||
val scalar = Nd4j.trueScalar(2.0f);
|
val scalar = Nd4j.scalar(2.0f);
|
||||||
val newShape = scalar.reshape(1, 1, 1, 1);
|
val newShape = scalar.reshape(1, 1, 1, 1);
|
||||||
|
|
||||||
assertEquals(4, newShape.rank());
|
assertEquals(4, newShape.rank());
|
||||||
|
@ -5958,7 +5958,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test(expected = IllegalStateException.class)
|
@Test(expected = IllegalStateException.class)
|
||||||
public void testTranspose2() {
|
public void testTranspose2() {
|
||||||
val scalar = Nd4j.trueScalar(2.f);
|
val scalar = Nd4j.scalar(2.f);
|
||||||
|
|
||||||
assertArrayEquals(new long[]{}, scalar.shape());
|
assertArrayEquals(new long[]{}, scalar.shape());
|
||||||
assertArrayEquals(new long[]{}, scalar.stride());
|
assertArrayEquals(new long[]{}, scalar.stride());
|
||||||
|
@ -5991,8 +5991,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
@Test
|
@Test
|
||||||
public void testScalarSqueeze() {
|
public void testScalarSqueeze() {
|
||||||
val scalar = Nd4j.create(new float[]{2.0f}, new long[]{1, 1});
|
val scalar = Nd4j.create(new float[]{2.0f}, new long[]{1, 1});
|
||||||
val output = Nd4j.trueScalar(0.0f);
|
val output = Nd4j.scalar(0.0f);
|
||||||
val exp = Nd4j.trueScalar(2.0f);
|
val exp = Nd4j.scalar(2.0f);
|
||||||
val op = DynamicCustomOp.builder("squeeze")
|
val op = DynamicCustomOp.builder("squeeze")
|
||||||
.addInputs(scalar)
|
.addInputs(scalar)
|
||||||
.addOutputs(output)
|
.addOutputs(output)
|
||||||
|
@ -6012,8 +6012,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
assertArrayEquals(new long[]{1}, scalar.shape());
|
assertArrayEquals(new long[]{1}, scalar.shape());
|
||||||
|
|
||||||
val output = Nd4j.trueScalar(0.0f);
|
val output = Nd4j.scalar(0.0f);
|
||||||
val exp = Nd4j.trueScalar(2.0f);
|
val exp = Nd4j.scalar(2.0f);
|
||||||
val op = DynamicCustomOp.builder("squeeze")
|
val op = DynamicCustomOp.builder("squeeze")
|
||||||
.addInputs(scalar)
|
.addInputs(scalar)
|
||||||
.addOutputs(output)
|
.addOutputs(output)
|
||||||
|
@ -6113,7 +6113,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
@Test
|
@Test
|
||||||
public void testValueArrayOf_2() {
|
public void testValueArrayOf_2() {
|
||||||
val scalar = Nd4j.valueArrayOf(new long[] {}, 2f);
|
val scalar = Nd4j.valueArrayOf(new long[] {}, 2f);
|
||||||
val exp = Nd4j.trueScalar(2f);
|
val exp = Nd4j.scalar(2f);
|
||||||
|
|
||||||
assertArrayEquals(exp.shape(), scalar.shape());
|
assertArrayEquals(exp.shape(), scalar.shape());
|
||||||
assertEquals(exp, scalar);
|
assertEquals(exp, scalar);
|
||||||
|
@ -6873,7 +6873,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
val exp_2 = Nd4j.create(new double[]{0.0, 1.0, 2.0}, new long[]{3});
|
val exp_2 = Nd4j.create(new double[]{0.0, 1.0, 2.0}, new long[]{3});
|
||||||
val exp_3 = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{3});
|
val exp_3 = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{3});
|
||||||
val arrayX = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{3});
|
val arrayX = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{3});
|
||||||
val arrayY = Nd4j.trueScalar(1.0);
|
val arrayY = Nd4j.scalar(1.0);
|
||||||
|
|
||||||
val arrayZ_1 = arrayX.add(arrayY);
|
val arrayZ_1 = arrayX.add(arrayY);
|
||||||
assertEquals(exp_1, arrayZ_1);
|
assertEquals(exp_1, arrayZ_1);
|
||||||
|
|
|
@ -104,7 +104,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicCreation_5() {
|
public void testBasicCreation_5() {
|
||||||
val scalar = Nd4j.trueScalar(new Integer(1));
|
val scalar = Nd4j.scalar(new Integer(1));
|
||||||
assertNotNull(scalar);
|
assertNotNull(scalar);
|
||||||
assertEquals(0, scalar.rank());
|
assertEquals(0, scalar.rank());
|
||||||
assertEquals(1, scalar.length());
|
assertEquals(1, scalar.length());
|
||||||
|
@ -114,7 +114,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicCreation_6() {
|
public void testBasicCreation_6() {
|
||||||
val scalar = Nd4j.trueScalar(1);
|
val scalar = Nd4j.scalar(1);
|
||||||
assertNotNull(scalar);
|
assertNotNull(scalar);
|
||||||
assertEquals(0, scalar.rank());
|
assertEquals(0, scalar.rank());
|
||||||
assertEquals(1, scalar.length());
|
assertEquals(1, scalar.length());
|
||||||
|
@ -124,7 +124,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicCreation_7() {
|
public void testBasicCreation_7() {
|
||||||
val scalar = Nd4j.trueScalar(1L);
|
val scalar = Nd4j.scalar(1L);
|
||||||
assertNotNull(scalar);
|
assertNotNull(scalar);
|
||||||
assertEquals(0, scalar.rank());
|
assertEquals(0, scalar.rank());
|
||||||
assertEquals(1, scalar.length());
|
assertEquals(1, scalar.length());
|
||||||
|
|
|
@ -439,7 +439,7 @@ public class ShapeTestsC extends BaseNd4jTest {
|
||||||
@Test
|
@Test
|
||||||
public void testReshapeToTrueScalar_1() {
|
public void testReshapeToTrueScalar_1() {
|
||||||
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
|
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
|
||||||
val exp = Nd4j.trueScalar(1.0f);
|
val exp = Nd4j.scalar(1.0f);
|
||||||
|
|
||||||
assertArrayEquals(new long[]{1, 1}, orig.shape());
|
assertArrayEquals(new long[]{1, 1}, orig.shape());
|
||||||
|
|
||||||
|
@ -452,7 +452,7 @@ public class ShapeTestsC extends BaseNd4jTest {
|
||||||
@Test
|
@Test
|
||||||
public void testReshapeToTrueScalar_2() {
|
public void testReshapeToTrueScalar_2() {
|
||||||
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1});
|
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1});
|
||||||
val exp = Nd4j.trueScalar(1.0f);
|
val exp = Nd4j.scalar(1.0f);
|
||||||
|
|
||||||
assertArrayEquals(new long[]{1}, orig.shape());
|
assertArrayEquals(new long[]{1}, orig.shape());
|
||||||
|
|
||||||
|
@ -478,7 +478,7 @@ public class ShapeTestsC extends BaseNd4jTest {
|
||||||
@Test
|
@Test
|
||||||
public void testReshapeToTrueScalar_4() {
|
public void testReshapeToTrueScalar_4() {
|
||||||
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
|
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
|
||||||
val exp = Nd4j.trueScalar(1.0f);
|
val exp = Nd4j.scalar(1.0f);
|
||||||
|
|
||||||
assertArrayEquals(new long[]{1, 1}, orig.shape());
|
assertArrayEquals(new long[]{1, 1}, orig.shape());
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue