fix for eclipse#8087 (#129)

*  fix for #8087

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove commented code.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* removing trueScalar.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove trueScalar.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
Branch: master
Author: Robert Altena, 2019-08-20 15:20:40 +09:00 (committed by GitHub)
Parent: 10d676e0b8
Commit: 38310777ee
14 changed files with 79 additions and 135 deletions
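For reviewers updating their own call sites, here is a minimal migration sketch, assuming only the two removed factory methods are affected. The wrapper class name is illustrative; the Nd4j.scalar overloads are the ones used throughout the diff below.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class TrueScalarMigrationSketch {
    public static void main(String[] args) {
        // Before this commit (deprecated, now removed):
        //   INDArray a = Nd4j.trueScalar(2.0f);
        //   INDArray b = Nd4j.trueScalar(DataType.INT, 5);

        // After this commit, the equivalent non-deprecated calls:
        INDArray a = Nd4j.scalar(2.0f);            // rank-0 FLOAT scalar
        INDArray b = Nd4j.scalar(DataType.INT, 5); // rank-0 INT scalar

        System.out.println(a.rank());      // 0 -- a true 0-D scalar, not a [1, 1] matrix
        System.out.println(b.dataType());  // INT
    }
}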

@@ -1280,62 +1280,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
         return create(new double[] {value}, new int[0], new int[0], offset);
     }
 
-    @Override
-    public INDArray trueScalar(DataType dataType, Number value) {
-        val ws = Nd4j.getMemoryManager().getCurrentWorkspace();
-        switch (dataType) {
-            case DOUBLE:
-                return create(new double[] {value.doubleValue()}, new long[] {}, new long[] {}, dataType, ws);
-            case FLOAT:
-                return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
-            case BFLOAT16:
-                return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
-            case HALF:
-                return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
-            case UINT32:
-            case INT:
-                return create(new int[] {value.intValue()}, new long[] {}, new long[] {}, dataType, ws);
-            case UINT64:
-            case LONG:
-                return create(new long[] {value.longValue()}, new long[] {}, new long[] {}, dataType, ws);
-            case UINT16:
-            case SHORT:
-                return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
-            case BYTE:
-                return create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, dataType, ws);
-            case UBYTE:
-                return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
-            case BOOL:
-                val b = value.byteValue();
-                val arr = create(new byte[] {b}, new long[] {}, new long[] {}, dataType, ws);
-                return arr;
-            default:
-                throw new UnsupportedOperationException("Unsupported data type used: " + dataType);
-        }
-    }
-
-    @Override
-    public INDArray trueScalar(Number value) {
-        val ws = Nd4j.getMemoryManager().getCurrentWorkspace();
-        if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends com.google.common.util.concurrent.AtomicDouble */
-            return create(new double[] {value.doubleValue()}, new long[] {}, new long[] {}, DataType.DOUBLE, ws);
-        else if (value instanceof Float)
-            return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, DataType.FLOAT, ws);
-        else if (value instanceof Long || value instanceof AtomicLong)
-            return create(new long[] {value.longValue()}, new long[] {}, new long[] {}, DataType.LONG, ws);
-        else if (value instanceof Integer || value instanceof AtomicInteger)
-            return create(new int[] {value.intValue()}, new long[] {}, new long[] {}, DataType.INT, ws);
-        else if (value instanceof Short)
-            return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, DataType.SHORT, ws);
-        else if (value instanceof Byte)
-            return create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, DataType.BYTE, ws);
-        else
-            throw new UnsupportedOperationException("Unsupported data type: [" + value.getClass().getSimpleName() + "]");
-    }
-
     public INDArray trueVector(boolean[] data) {
         return create(data, new long[] {data.length}, new long[]{1}, DataType.BOOL, Nd4j.getMemoryManager().getCurrentWorkspace());
     }
@@ -1364,8 +1308,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
         return create(data, new long[] {data.length}, new long[]{1}, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
     }
 
     /**
      * Create a scalar nd array with the specified value and offset
      *

@@ -990,12 +990,6 @@ public interface NDArrayFactory {
     INDArray empty(DataType type);
 
-    @Deprecated
-    INDArray trueScalar(Number value);
-
-    @Deprecated
-    INDArray trueScalar(DataType dataType, Number value);
-
     @Deprecated
     INDArray trueVector(boolean[] data);
 
     @Deprecated

@@ -3736,19 +3736,6 @@ public class Nd4j {
         return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace());
     }
 
-    /**
-     * This method creates new 0D INDArray, aka scalar.
-     *
-     * PLEASE NOTE: Temporary method, added to ensure backward compatibility
-     * @param scalar data for INDArray.
-     * @return new INDArray
-     *
-     * @deprecated Use Nd4j.scalar methods, such as {@link #scalar(double)} or {@link #scalar(DataType, Number)}
-     */
-    @Deprecated
-    public static INDArray trueScalar(Number scalar) {
-        return INSTANCE.trueScalar(scalar);
-    }
-
     /**
      * @deprecated Use {@link #createFromArray(boolean...)}
      */
@@ -5087,9 +5074,35 @@ public class Nd4j {
      * @param value the value to initialize the scalar with
      * @return the created ndarray
      */
-    @SuppressWarnings("deprecation")
     public static INDArray scalar(DataType dataType, Number value) {
-        return INSTANCE.trueScalar(dataType, value);
+        val ws = Nd4j.getMemoryManager().getCurrentWorkspace();
+        switch (dataType) {
+            case DOUBLE:
+                return INSTANCE.create(new double[] {value.doubleValue()}, new long[] {}, new long[] {}, dataType, ws);
+            case FLOAT:
+            case BFLOAT16:
+            case HALF:
+                return INSTANCE.create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
+            case UINT32:
+            case INT:
+                return INSTANCE.create(new int[] {value.intValue()}, new long[] {}, new long[] {}, dataType, ws);
+            case UINT64:
+            case LONG:
+                return INSTANCE.create(new long[] {value.longValue()}, new long[] {}, new long[] {}, dataType, ws);
+            case UINT16:
+            case SHORT:
+                return INSTANCE.create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
+            case BYTE:
+                return INSTANCE.create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, dataType, ws);
+            case UBYTE:
+                return INSTANCE.create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
+            case BOOL:
+                return INSTANCE.create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, dataType, ws);
+            default:
+                throw new UnsupportedOperationException("Unsupported data type used: " + dataType);
+        }
     }
 
     /**
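As a quick reference, a hedged usage sketch of the reworked Nd4j.scalar(DataType, Number) shown above: the boxed Number is narrowed per target type (doubleValue(), intValue(), and so on) and the result is allocated as a rank-0 array in the current workspace. The class name and printed checks below are illustrative, not part of this commit.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class TypedScalarSketch {
    public static void main(String[] args) {
        INDArray f16 = Nd4j.scalar(DataType.HALF, 3.5);  // stored via floatValue()
        INDArray i32 = Nd4j.scalar(DataType.INT, 7.9);   // stored via intValue(), i.e. truncated to 7
        INDArray flg = Nd4j.scalar(DataType.BOOL, 1);    // stored via byteValue()

        System.out.println(f16.rank());      // 0 -- empty shape, a true 0-D scalar
        System.out.println(i32);             // prints 7: narrowing happens at creation time
        System.out.println(flg.dataType());  // BOOL
    }
}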

@@ -251,11 +251,6 @@ public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
         return null;
     }
 
-    @Override
-    public INDArray trueScalar(Number value) {
-        throw new UnsupportedOperationException();
-    }
-
     @Override
     public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
         return null;

@@ -504,7 +504,7 @@ public class MiscOpValidation extends BaseOpValidation {
             double exp = Nd4j.diag(in).sumNumber().doubleValue();
 
             TestCase tc = new TestCase(sd)
-                    .expected(trace, Nd4j.trueScalar(exp))
+                    .expected(trace, Nd4j.scalar(exp))
                     .testName(Arrays.toString(inShape));
 
             String err = OpValidation.validate(tc);
@@ -1296,7 +1296,7 @@ public class MiscOpValidation extends BaseOpValidation {
             if(shape.length > 0){
                 arr = Nd4j.rand(shape);
             } else {
-                arr = Nd4j.trueScalar(Nd4j.rand(new int[]{1,1}).getDouble(0));
+                arr = Nd4j.scalar(Nd4j.rand(new int[]{1,1}).getDouble(0));
             }
             SDVariable var = sd.var("in", arr);
             SDVariable xLike;
@@ -1388,7 +1388,7 @@ public class MiscOpValidation extends BaseOpValidation {
             INDArray inArr;
             if (shape == null) {
-                inArr = Nd4j.trueScalar(1.0);
+                inArr = Nd4j.scalar(1.0);
             } else {
                 inArr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(shape);
             }

@@ -79,7 +79,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
             if (keepDims) {
                 dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
             } else {
-                dLdOut = Nd4j.trueScalar(0.5);
+                dLdOut = Nd4j.scalar(0.5);
             }
             INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), 0.5);
             INDArray dLdIn = Nd4j.createUninitialized(3, 4);
@@ -164,7 +164,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
             if (keepDims) {
                 dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
             } else {
-                dLdOut = Nd4j.trueScalar(0.5);
+                dLdOut = Nd4j.scalar(0.5);
             }
             INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), 0.5 / preReduceInput.length());
             INDArray dLdIn = Nd4j.createUninitialized(3, 4);
@@ -178,7 +178,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
     @Test
     public void testMeanBP_Rank1() {
-        INDArray dLdOut = Nd4j.trueScalar(0.5);
+        INDArray dLdOut = Nd4j.scalar(0.5);
         INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3});
         INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5/3);
@@ -261,7 +261,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
             if (keepDims) {
                 dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
             } else {
-                dLdOut = Nd4j.trueScalar(0.5);
+                dLdOut = Nd4j.scalar(0.5);
             }
             INDArray dLdInExpected = Nd4j.zeros(preReduceInput.shape());
             dLdInExpected.putScalar(new int[]{2, 2}, 0.5); //Minimum value: position at [2,2]
@@ -343,7 +343,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
             if (keepDims) {
                 dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
             } else {
-                dLdOut = Nd4j.trueScalar(0.5);
+                dLdOut = Nd4j.scalar(0.5);
             }
             INDArray dLdInExpected = Nd4j.zeros(preReduceInput.shape());
             dLdInExpected.putScalar(new int[]{2, 2}, 0.5); //Maximum value: position at [2,2]
@@ -415,7 +415,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
             if (keepDims) {
                 dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
             } else {
-                dLdOut = Nd4j.trueScalar(0.5);
+                dLdOut = Nd4j.scalar(0.5);
             }
             double prod = preReduceInput.prodNumber().doubleValue();
             INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), prod).divi(preReduceInput).muli(0.5);
@@ -500,7 +500,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
             if (keepDims) {
                 dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
             } else {
-                dLdOut = Nd4j.trueScalar(0.5);
+                dLdOut = Nd4j.scalar(0.5);
             }
             double stdev = preReduceInput.stdNumber(biasCorrected).doubleValue();
@@ -523,7 +523,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
     @Test
     public void testStdevBP_Rank1() {
-        INDArray dLdOut = Nd4j.trueScalar(0.5);
+        INDArray dLdOut = Nd4j.scalar(0.5);
         INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3});
         double stdev = preReduceInput.stdNumber(true).doubleValue();
         double mean = preReduceInput.meanNumber().doubleValue();
@@ -602,7 +602,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
             if (keepDims) {
                 dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
             } else {
-                dLdOut = Nd4j.trueScalar(0.5);
+                dLdOut = Nd4j.scalar(0.5);
             }
             double var = preReduceInput.var(biasCorrected).getDouble(0);
@@ -811,7 +811,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
             if (keepDims) {
                 dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
             } else {
-                dLdOut = Nd4j.trueScalar(0.5);
+                dLdOut = Nd4j.scalar(0.5);
             }
             INDArray dLdInExpected = sgn.muli(0.5);
             INDArray dLdIn = Nd4j.createUninitialized(3, 4);
@@ -873,7 +873,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
             if (keepDims) {
                 dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
             } else {
-                dLdOut = Nd4j.trueScalar(0.5);
+                dLdOut = Nd4j.scalar(0.5);
             }
             INDArray dLdInExpected = sgn.mul(max).mul(0.5);
             INDArray dLdIn = Nd4j.createUninitialized(3, 4);

@@ -234,12 +234,12 @@ public class ReductionOpValidation extends BaseOpValidation {
                 case 10:
                     loss = sd.math().countNonZero("loss", input);
                     name = "countNonZero";
-                    tc.expectedOutput("loss", Nd4j.trueScalar(inputArr.length()));
+                    tc.expectedOutput("loss", Nd4j.scalar(inputArr.length()));
                     break;
                 case 11:
                     loss = sd.math().countZero("loss", input);
                     name = "countZero";
-                    tc.expectedOutput("loss", Nd4j.trueScalar(0));
+                    tc.expectedOutput("loss", Nd4j.scalar(0));
                     break;
                 case 12:
                     loss = sd.math().amax("loss", input);
@@ -280,21 +280,21 @@ public class ReductionOpValidation extends BaseOpValidation {
                     name = "sqnorm";
                     loss = sd.squaredNorm("loss", input);
                     double norm2 = inputArr.norm2Number().doubleValue();
-                    tc.expected("loss", Nd4j.trueScalar(norm2 * norm2));
+                    tc.expected("loss", Nd4j.scalar(norm2 * norm2));
                     break;
                 case 19:
                     inputArr = Nd4j.rand(minibatch, nOut);
                     name = "logEntropy";
                     loss = sd.math().logEntropy("loss", input);
                     double logEntropy = inputArr.logEntropyNumber().doubleValue();
-                    tc.expected(loss, Nd4j.trueScalar(logEntropy));
+                    tc.expected(loss, Nd4j.scalar(logEntropy));
                     break;
                 case 20:
                     inputArr = Nd4j.rand(minibatch, nOut);
                     name = "shannonEntropy";
                     loss = sd.math().shannonEntropy("loss", input);
                     double shannonEntropy = inputArr.shannonEntropyNumber().doubleValue();
-                    tc.expected(loss, Nd4j.trueScalar(shannonEntropy));
+                    tc.expected(loss, Nd4j.scalar(shannonEntropy));
                     if (OpValidationSuite.IGNORE_FAILING) {
                         continue;
                     }

@@ -1030,7 +1030,7 @@ public class ShapeOpValidation extends BaseOpValidation {
         //Case 1: shape is provided + scalar
 
         sd = SameDiff.create();
-        ia = Nd4j.trueScalar(3.0);
+        ia = Nd4j.scalar(3.0);
         in = sd.var(ia);
         constant = sd.constant(in, 3,4,5);
         INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
@@ -1169,7 +1169,7 @@ public class ShapeOpValidation extends BaseOpValidation {
         double d = new LUDecomposition(CheckUtil.convertToApacheMatrix(in)).getDeterminant();
 
-        INDArray outExp = Nd4j.trueScalar(d);
+        INDArray outExp = Nd4j.scalar(d);
 
         String err = OpValidation.validate(new TestCase(sd)
                 .expected(md.getVarName(), outExp));
@@ -1193,7 +1193,7 @@ public class ShapeOpValidation extends BaseOpValidation {
         assertEquals(d, d2, 1e-5);
 
-        INDArray outExp = Nd4j.trueScalar(d);
+        INDArray outExp = Nd4j.scalar(d);
 
         String err = OpValidation.validate(new TestCase(sd)
                 .expected(md.getVarName(), outExp));
@@ -1224,7 +1224,7 @@ public class ShapeOpValidation extends BaseOpValidation {
                 - a[0][0] * a[1][2] * a[2][1];
         assertEquals(d, d2, 1e-6); //Manual calc and Apache commons both match: 0.03589524995561552
 
-        INDArray outExp = Nd4j.trueScalar(d);
+        INDArray outExp = Nd4j.scalar(d);
 
         String err = OpValidation.validate(new TestCase(sd)
                 .expected(md.getVarName(), outExp));
@@ -1247,7 +1247,7 @@ public class ShapeOpValidation extends BaseOpValidation {
         //System.out.println(d);
 
         String err = OpValidation.validate(new TestCase(sd)
-                .expected(md.getVarName(), Nd4j.trueScalar(d)));
+                .expected(md.getVarName(), Nd4j.scalar(d)));
         assertNull(err);
     }
@@ -1792,7 +1792,7 @@ public class ShapeOpValidation extends BaseOpValidation {
     @Test
     public void testSplit1(){
         INDArray in = Nd4j.linspace(1,10,10).reshape(10);
-        INDArray axis = Nd4j.trueScalar(-1);
+        INDArray axis = Nd4j.scalar(-1);
 
         INDArray out1 = Nd4j.create(new long[]{5});
         INDArray out2 = Nd4j.create(new long[]{5});

@@ -1408,7 +1408,7 @@ public class TransformOpValidation extends BaseOpValidation {
     public void testScatterOpsScalar(){
         for(String s : new String[]{"add", "sub", "mul", "div"}) {
             INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
-            INDArray indices = Nd4j.trueScalar(5);
+            INDArray indices = Nd4j.scalar(5);
             INDArray upd = Nd4j.create(new double[]{10, 20, 30});
 
             //The non-scalar case works:
@@ -1452,7 +1452,7 @@ public class TransformOpValidation extends BaseOpValidation {
     public void testPad(){
         INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
         INDArray pad = Nd4j.create(new double[]{1,1}, new long[]{1,2}).castTo(DataType.LONG);
-        INDArray value = Nd4j.trueScalar(10.0);
+        INDArray value = Nd4j.scalar(10.0);
 
         INDArray out = Nd4j.create(new long[]{7});

@@ -115,7 +115,7 @@ public class ByteOrderTests extends BaseNd4jTest {
 
     @Test
     public void testScalarEncoding() {
-        val scalar = Nd4j.trueScalar(2.0f);
+        val scalar = Nd4j.scalar(2.0f);
 
         FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0);
         val fb = scalar.toFlatArray(bufferBuilder);

@@ -797,7 +797,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
         Nd4j.create(1);
         val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
         assertNotNull(tg);
-        val input = Nd4j.trueScalar(4.0);
+        val input = Nd4j.scalar(4.0);
         tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
 
         tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));
@@ -815,7 +815,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
         Nd4j.create(1);
         val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
         assertNotNull(tg);
-        val input = Nd4j.trueScalar(9.0);
+        val input = Nd4j.scalar(9.0);
         tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
 
         //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0.fb"));
@@ -835,7 +835,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
         val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
         assertNotNull(tg);
         val input0 = Nd4j.create(2, 2).assign(-4.0);
-        val input1 = Nd4j.trueScalar(1.0);
+        val input1 = Nd4j.scalar(1.0);
         tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
         tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
@@ -855,7 +855,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
         val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
         assertNotNull(tg);
         val input0 = Nd4j.create(2, 2).assign(-9.0);
-        val input1 = Nd4j.trueScalar(1.0);
+        val input1 = Nd4j.scalar(1.0);
         tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
         tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
@@ -964,7 +964,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
         assertNotNull(tg);
         val input0 = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {2, 2});
-        val input1 = Nd4j.trueScalar(11f);
+        val input1 = Nd4j.scalar(11f);
         tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
         tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

@@ -5880,9 +5880,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
 
     @Test
     public void testScalar_2() {
-        val scalar = Nd4j.trueScalar(2.0f);
-        val scalar2 = Nd4j.trueScalar(2.0f);
-        val scalar3 = Nd4j.trueScalar(3.0f);
+        val scalar = Nd4j.scalar(2.0f);
+        val scalar2 = Nd4j.scalar(2.0f);
+        val scalar3 = Nd4j.scalar(3.0f);
 
         assertTrue(scalar.isScalar());
         assertEquals(1, scalar.length());
@@ -5917,7 +5917,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
     @Test
     public void testVectorScalar_2() {
         val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5});
-        val scalar = Nd4j.trueScalar(2.0f);
+        val scalar = Nd4j.scalar(2.0f);
         val exp = Nd4j.trueVector(new float[]{3, 4, 5, 6, 7});
 
         vector.addi(scalar);
@@ -5927,7 +5927,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
     @Test
     public void testReshapeScalar() {
-        val scalar = Nd4j.trueScalar(2.0f);
+        val scalar = Nd4j.scalar(2.0f);
         val newShape = scalar.reshape(1, 1, 1, 1);
 
         assertEquals(4, newShape.rank());
@@ -5958,7 +5958,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
     @Test(expected = IllegalStateException.class)
     public void testTranspose2() {
-        val scalar = Nd4j.trueScalar(2.f);
+        val scalar = Nd4j.scalar(2.f);
 
         assertArrayEquals(new long[]{}, scalar.shape());
         assertArrayEquals(new long[]{}, scalar.stride());
@@ -5991,8 +5991,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
     @Test
     public void testScalarSqueeze() {
         val scalar = Nd4j.create(new float[]{2.0f}, new long[]{1, 1});
-        val output = Nd4j.trueScalar(0.0f);
-        val exp = Nd4j.trueScalar(2.0f);
+        val output = Nd4j.scalar(0.0f);
+        val exp = Nd4j.scalar(2.0f);
 
         val op = DynamicCustomOp.builder("squeeze")
                 .addInputs(scalar)
                 .addOutputs(output)
@@ -6012,8 +6012,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
         assertArrayEquals(new long[]{1}, scalar.shape());
 
-        val output = Nd4j.trueScalar(0.0f);
-        val exp = Nd4j.trueScalar(2.0f);
+        val output = Nd4j.scalar(0.0f);
+        val exp = Nd4j.scalar(2.0f);
 
         val op = DynamicCustomOp.builder("squeeze")
                 .addInputs(scalar)
                 .addOutputs(output)
@@ -6113,7 +6113,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
     @Test
     public void testValueArrayOf_2() {
         val scalar = Nd4j.valueArrayOf(new long[] {}, 2f);
-        val exp = Nd4j.trueScalar(2f);
+        val exp = Nd4j.scalar(2f);
 
         assertArrayEquals(exp.shape(), scalar.shape());
         assertEquals(exp, scalar);
@@ -6873,7 +6873,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
         val exp_2 = Nd4j.create(new double[]{0.0, 1.0, 2.0}, new long[]{3});
         val exp_3 = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{3});
 
         val arrayX = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{3});
-        val arrayY = Nd4j.trueScalar(1.0);
+        val arrayY = Nd4j.scalar(1.0);
 
         val arrayZ_1 = arrayX.add(arrayY);
         assertEquals(exp_1, arrayZ_1);

@@ -104,7 +104,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
     @Test
     public void testBasicCreation_5() {
-        val scalar = Nd4j.trueScalar(new Integer(1));
+        val scalar = Nd4j.scalar(new Integer(1));
 
         assertNotNull(scalar);
         assertEquals(0, scalar.rank());
         assertEquals(1, scalar.length());
@@ -114,7 +114,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
     @Test
     public void testBasicCreation_6() {
-        val scalar = Nd4j.trueScalar(1);
+        val scalar = Nd4j.scalar(1);
 
         assertNotNull(scalar);
         assertEquals(0, scalar.rank());
         assertEquals(1, scalar.length());
@@ -124,7 +124,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
     @Test
     public void testBasicCreation_7() {
-        val scalar = Nd4j.trueScalar(1L);
+        val scalar = Nd4j.scalar(1L);
 
         assertNotNull(scalar);
         assertEquals(0, scalar.rank());
         assertEquals(1, scalar.length());

@@ -439,7 +439,7 @@ public class ShapeTestsC extends BaseNd4jTest {
     @Test
     public void testReshapeToTrueScalar_1() {
         val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
-        val exp = Nd4j.trueScalar(1.0f);
+        val exp = Nd4j.scalar(1.0f);
 
         assertArrayEquals(new long[]{1, 1}, orig.shape());
@@ -452,7 +452,7 @@ public class ShapeTestsC extends BaseNd4jTest {
     @Test
     public void testReshapeToTrueScalar_2() {
         val orig = Nd4j.create(new float[]{1.0f}, new int[]{1});
-        val exp = Nd4j.trueScalar(1.0f);
+        val exp = Nd4j.scalar(1.0f);
 
         assertArrayEquals(new long[]{1}, orig.shape());
@@ -478,7 +478,7 @@ public class ShapeTestsC extends BaseNd4jTest {
     @Test
     public void testReshapeToTrueScalar_4() {
         val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
-        val exp = Nd4j.trueScalar(1.0f);
+        val exp = Nd4j.scalar(1.0f);
 
         assertArrayEquals(new long[]{1, 1}, orig.shape());