fix for eclipse#8087 (#129)

*  fix for #8087

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove commented code.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* removing trueScalar.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove tueScalar.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-20 15:20:40 +09:00 committed by GitHub
parent 10d676e0b8
commit 38310777ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 79 additions and 135 deletions

View File

@ -1280,62 +1280,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
return create(new double[] {value}, new int[0], new int[0], offset);
}
@Override
public INDArray trueScalar(DataType dataType, Number value) {
val ws = Nd4j.getMemoryManager().getCurrentWorkspace();
switch (dataType) {
case DOUBLE:
return create(new double[] {value.doubleValue()}, new long[] {}, new long[] {}, dataType, ws);
case FLOAT:
return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
case BFLOAT16:
return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
case HALF:
return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
case UINT32:
case INT:
return create(new int[] {value.intValue()}, new long[] {}, new long[] {}, dataType, ws);
case UINT64:
case LONG:
return create(new long[] {value.longValue()}, new long[] {}, new long[] {}, dataType, ws);
case UINT16:
case SHORT:
return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
case BYTE:
return create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, dataType, ws);
case UBYTE:
return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
case BOOL:
val b = value.byteValue();
val arr = create(new byte[] {b}, new long[] {}, new long[] {}, dataType, ws);
return arr;
default:
throw new UnsupportedOperationException("Unsupported data type used: " + dataType);
}
}
@Override
public INDArray trueScalar(Number value) {
val ws = Nd4j.getMemoryManager().getCurrentWorkspace();
if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends com.google.common.util.concurrent.AtomicDouble */
return create(new double[] {value.doubleValue()}, new long[] {}, new long[] {}, DataType.DOUBLE, ws);
else if (value instanceof Float)
return create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, DataType.FLOAT, ws);
else if (value instanceof Long || value instanceof AtomicLong)
return create(new long[] {value.longValue()}, new long[] {}, new long[] {}, DataType.LONG, ws);
else if (value instanceof Integer || value instanceof AtomicInteger)
return create(new int[] {value.intValue()}, new long[] {}, new long[] {}, DataType.INT, ws);
else if (value instanceof Short)
return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, DataType.SHORT, ws);
else if (value instanceof Byte)
return create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, DataType.BYTE, ws);
else
throw new UnsupportedOperationException("Unsupported data type: [" + value.getClass().getSimpleName() + "]");
}
public INDArray trueVector(boolean[] data) {
return create(data, new long[] {data.length}, new long[]{1}, DataType.BOOL, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@ -1364,8 +1308,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
return create(data, new long[] {data.length}, new long[]{1}, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
* Create a scalar nd array with the specified value and offset
*

View File

@ -990,12 +990,6 @@ public interface NDArrayFactory {
INDArray empty(DataType type);
@Deprecated
INDArray trueScalar(Number value);
@Deprecated
INDArray trueScalar(DataType dataType, Number value);
@Deprecated
INDArray trueVector(boolean[] data);
@Deprecated

View File

@ -3736,19 +3736,6 @@ public class Nd4j {
return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
* This method creates new 0D INDArray, aka scalar.
*
* PLEASE NOTE: Temporary method, added to ensure backward compatibility
* @param scalar data for INDArray.
* @return new INDArray
* * @deprecated Use Nd4j.scalar methods, such as {@link #scalar(double)} or {@link #scalar(DataType, Number)}
*/
@Deprecated
public static INDArray trueScalar(Number scalar) {
return INSTANCE.trueScalar(scalar);
}
/**
* @deprecated Use {@link #createFromArray(boolean...)}
*/
@ -5087,9 +5074,35 @@ public class Nd4j {
* @param value the value to initialize the scalar with
* @return the created ndarray
*/
@SuppressWarnings("deprecation")
public static INDArray scalar(DataType dataType, Number value) {
return INSTANCE.trueScalar(dataType, value);
val ws = Nd4j.getMemoryManager().getCurrentWorkspace();
switch (dataType) {
case DOUBLE:
return INSTANCE.create(new double[] {value.doubleValue()}, new long[] {}, new long[] {}, dataType, ws);
case FLOAT:
case BFLOAT16:
case HALF:
return INSTANCE.create(new float[] {value.floatValue()}, new long[] {}, new long[] {}, dataType, ws);
case UINT32:
case INT:
return INSTANCE.create(new int[] {value.intValue()}, new long[] {}, new long[] {}, dataType, ws);
case UINT64:
case LONG:
return INSTANCE.create(new long[] {value.longValue()}, new long[] {}, new long[] {}, dataType, ws);
case UINT16:
case SHORT:
return INSTANCE.create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
case BYTE:
return INSTANCE.create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, dataType, ws);
case UBYTE:
return INSTANCE.create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, dataType, ws);
case BOOL:
return INSTANCE.create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, dataType, ws);
default:
throw new UnsupportedOperationException("Unsupported data type used: " + dataType);
}
}
/**

View File

@ -251,11 +251,6 @@ public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
return null;
}
@Override
public INDArray trueScalar(Number value) {
throw new UnsupportedOperationException();
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;

View File

@ -504,7 +504,7 @@ public class MiscOpValidation extends BaseOpValidation {
double exp = Nd4j.diag(in).sumNumber().doubleValue();
TestCase tc = new TestCase(sd)
.expected(trace, Nd4j.trueScalar(exp))
.expected(trace, Nd4j.scalar(exp))
.testName(Arrays.toString(inShape));
String err = OpValidation.validate(tc);
@ -1296,7 +1296,7 @@ public class MiscOpValidation extends BaseOpValidation {
if(shape.length > 0){
arr = Nd4j.rand(shape);
} else {
arr = Nd4j.trueScalar(Nd4j.rand(new int[]{1,1}).getDouble(0));
arr = Nd4j.scalar(Nd4j.rand(new int[]{1,1}).getDouble(0));
}
SDVariable var = sd.var("in", arr);
SDVariable xLike;
@ -1388,7 +1388,7 @@ public class MiscOpValidation extends BaseOpValidation {
INDArray inArr;
if (shape == null) {
inArr = Nd4j.trueScalar(1.0);
inArr = Nd4j.scalar(1.0);
} else {
inArr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(shape);
}

View File

@ -79,7 +79,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
if (keepDims) {
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
} else {
dLdOut = Nd4j.trueScalar(0.5);
dLdOut = Nd4j.scalar(0.5);
}
INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), 0.5);
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
@ -164,7 +164,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
if (keepDims) {
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
} else {
dLdOut = Nd4j.trueScalar(0.5);
dLdOut = Nd4j.scalar(0.5);
}
INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), 0.5 / preReduceInput.length());
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
@ -178,7 +178,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testMeanBP_Rank1() {
INDArray dLdOut = Nd4j.trueScalar(0.5);
INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3});
INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5/3);
@ -261,7 +261,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
if (keepDims) {
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
} else {
dLdOut = Nd4j.trueScalar(0.5);
dLdOut = Nd4j.scalar(0.5);
}
INDArray dLdInExpected = Nd4j.zeros(preReduceInput.shape());
dLdInExpected.putScalar(new int[]{2, 2}, 0.5); //Minimum value: position at [2,2]
@ -343,7 +343,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
if (keepDims) {
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
} else {
dLdOut = Nd4j.trueScalar(0.5);
dLdOut = Nd4j.scalar(0.5);
}
INDArray dLdInExpected = Nd4j.zeros(preReduceInput.shape());
dLdInExpected.putScalar(new int[]{2, 2}, 0.5); //Maximum value: position at [2,2]
@ -415,7 +415,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
if (keepDims) {
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
} else {
dLdOut = Nd4j.trueScalar(0.5);
dLdOut = Nd4j.scalar(0.5);
}
double prod = preReduceInput.prodNumber().doubleValue();
INDArray dLdInExpected = Nd4j.valueArrayOf(preReduceInput.shape(), prod).divi(preReduceInput).muli(0.5);
@ -500,7 +500,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
if (keepDims) {
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
} else {
dLdOut = Nd4j.trueScalar(0.5);
dLdOut = Nd4j.scalar(0.5);
}
double stdev = preReduceInput.stdNumber(biasCorrected).doubleValue();
@ -523,7 +523,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testStdevBP_Rank1() {
INDArray dLdOut = Nd4j.trueScalar(0.5);
INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3});
double stdev = preReduceInput.stdNumber(true).doubleValue();
double mean = preReduceInput.meanNumber().doubleValue();
@ -602,7 +602,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
if (keepDims) {
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
} else {
dLdOut = Nd4j.trueScalar(0.5);
dLdOut = Nd4j.scalar(0.5);
}
double var = preReduceInput.var(biasCorrected).getDouble(0);
@ -811,7 +811,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
if (keepDims) {
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
} else {
dLdOut = Nd4j.trueScalar(0.5);
dLdOut = Nd4j.scalar(0.5);
}
INDArray dLdInExpected = sgn.muli(0.5);
INDArray dLdIn = Nd4j.createUninitialized(3, 4);
@ -873,7 +873,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
if (keepDims) {
dLdOut = Nd4j.valueArrayOf(new long[]{1, 1}, 0.5);
} else {
dLdOut = Nd4j.trueScalar(0.5);
dLdOut = Nd4j.scalar(0.5);
}
INDArray dLdInExpected = sgn.mul(max).mul(0.5);
INDArray dLdIn = Nd4j.createUninitialized(3, 4);

View File

@ -234,12 +234,12 @@ public class ReductionOpValidation extends BaseOpValidation {
case 10:
loss = sd.math().countNonZero("loss", input);
name = "countNonZero";
tc.expectedOutput("loss", Nd4j.trueScalar(inputArr.length()));
tc.expectedOutput("loss", Nd4j.scalar(inputArr.length()));
break;
case 11:
loss = sd.math().countZero("loss", input);
name = "countZero";
tc.expectedOutput("loss", Nd4j.trueScalar(0));
tc.expectedOutput("loss", Nd4j.scalar(0));
break;
case 12:
loss = sd.math().amax("loss", input);
@ -280,21 +280,21 @@ public class ReductionOpValidation extends BaseOpValidation {
name = "sqnorm";
loss = sd.squaredNorm("loss", input);
double norm2 = inputArr.norm2Number().doubleValue();
tc.expected("loss", Nd4j.trueScalar(norm2 * norm2));
tc.expected("loss", Nd4j.scalar(norm2 * norm2));
break;
case 19:
inputArr = Nd4j.rand(minibatch, nOut);
name = "logEntropy";
loss = sd.math().logEntropy("loss", input);
double logEntropy = inputArr.logEntropyNumber().doubleValue();
tc.expected(loss, Nd4j.trueScalar(logEntropy));
tc.expected(loss, Nd4j.scalar(logEntropy));
break;
case 20:
inputArr = Nd4j.rand(minibatch, nOut);
name = "shannonEntropy";
loss = sd.math().shannonEntropy("loss", input);
double shannonEntropy = inputArr.shannonEntropyNumber().doubleValue();
tc.expected(loss, Nd4j.trueScalar(shannonEntropy));
tc.expected(loss, Nd4j.scalar(shannonEntropy));
if (OpValidationSuite.IGNORE_FAILING) {
continue;
}

View File

@ -1030,7 +1030,7 @@ public class ShapeOpValidation extends BaseOpValidation {
//Case 1: shape is provided + scalar
sd = SameDiff.create();
ia = Nd4j.trueScalar(3.0);
ia = Nd4j.scalar(3.0);
in = sd.var(ia);
constant = sd.constant(in, 3,4,5);
INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
@ -1169,7 +1169,7 @@ public class ShapeOpValidation extends BaseOpValidation {
double d = new LUDecomposition(CheckUtil.convertToApacheMatrix(in)).getDeterminant();
INDArray outExp = Nd4j.trueScalar(d);
INDArray outExp = Nd4j.scalar(d);
String err = OpValidation.validate(new TestCase(sd)
.expected(md.getVarName(), outExp));
@ -1193,7 +1193,7 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(d, d2, 1e-5);
INDArray outExp = Nd4j.trueScalar(d);
INDArray outExp = Nd4j.scalar(d);
String err = OpValidation.validate(new TestCase(sd)
.expected(md.getVarName(), outExp));
@ -1224,7 +1224,7 @@ public class ShapeOpValidation extends BaseOpValidation {
- a[0][0] * a[1][2] * a[2][1];
assertEquals(d, d2, 1e-6); //Manual calc and Apache commons both match: 0.03589524995561552
INDArray outExp = Nd4j.trueScalar(d);
INDArray outExp = Nd4j.scalar(d);
String err = OpValidation.validate(new TestCase(sd)
.expected(md.getVarName(), outExp));
@ -1247,7 +1247,7 @@ public class ShapeOpValidation extends BaseOpValidation {
//System.out.println(d);
String err = OpValidation.validate(new TestCase(sd)
.expected(md.getVarName(), Nd4j.trueScalar(d)));
.expected(md.getVarName(), Nd4j.scalar(d)));
assertNull(err);
}
@ -1792,7 +1792,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testSplit1(){
INDArray in = Nd4j.linspace(1,10,10).reshape(10);
INDArray axis = Nd4j.trueScalar(-1);
INDArray axis = Nd4j.scalar(-1);
INDArray out1 = Nd4j.create(new long[]{5});
INDArray out2 = Nd4j.create(new long[]{5});

View File

@ -1408,7 +1408,7 @@ public class TransformOpValidation extends BaseOpValidation {
public void testScatterOpsScalar(){
for(String s : new String[]{"add", "sub", "mul", "div"}) {
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
INDArray indices = Nd4j.trueScalar(5);
INDArray indices = Nd4j.scalar(5);
INDArray upd = Nd4j.create(new double[]{10, 20, 30});
//The non-scalar case works:
@ -1452,7 +1452,7 @@ public class TransformOpValidation extends BaseOpValidation {
public void testPad(){
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
INDArray pad = Nd4j.create(new double[]{1,1}, new long[]{1,2}).castTo(DataType.LONG);
INDArray value = Nd4j.trueScalar(10.0);
INDArray value = Nd4j.scalar(10.0);
INDArray out = Nd4j.create(new long[]{7});

View File

@ -115,7 +115,7 @@ public class ByteOrderTests extends BaseNd4jTest {
@Test
public void testScalarEncoding() {
val scalar = Nd4j.trueScalar(2.0f);
val scalar = Nd4j.scalar(2.0f);
FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0);
val fb = scalar.toFlatArray(bufferBuilder);

View File

@ -797,7 +797,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
Nd4j.create(1);
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
assertNotNull(tg);
val input = Nd4j.trueScalar(4.0);
val input = Nd4j.scalar(4.0);
tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));
@ -815,7 +815,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
Nd4j.create(1);
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
assertNotNull(tg);
val input = Nd4j.trueScalar(9.0);
val input = Nd4j.scalar(9.0);
tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
//tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0.fb"));
@ -835,7 +835,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
assertNotNull(tg);
val input0 = Nd4j.create(2, 2).assign(-4.0);
val input1 = Nd4j.trueScalar(1.0);
val input1 = Nd4j.scalar(1.0);
tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
@ -855,7 +855,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
assertNotNull(tg);
val input0 = Nd4j.create(2, 2).assign(-9.0);
val input1 = Nd4j.trueScalar(1.0);
val input1 = Nd4j.scalar(1.0);
tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
@ -964,7 +964,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
assertNotNull(tg);
val input0 = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {2, 2});
val input1 = Nd4j.trueScalar(11f);
val input1 = Nd4j.scalar(11f);
tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

View File

@ -5880,9 +5880,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testScalar_2() {
val scalar = Nd4j.trueScalar(2.0f);
val scalar2 = Nd4j.trueScalar(2.0f);
val scalar3 = Nd4j.trueScalar(3.0f);
val scalar = Nd4j.scalar(2.0f);
val scalar2 = Nd4j.scalar(2.0f);
val scalar3 = Nd4j.scalar(3.0f);
assertTrue(scalar.isScalar());
assertEquals(1, scalar.length());
@ -5917,7 +5917,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testVectorScalar_2() {
val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5});
val scalar = Nd4j.trueScalar(2.0f);
val scalar = Nd4j.scalar(2.0f);
val exp = Nd4j.trueVector(new float[]{3, 4, 5, 6, 7});
vector.addi(scalar);
@ -5927,7 +5927,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testReshapeScalar() {
val scalar = Nd4j.trueScalar(2.0f);
val scalar = Nd4j.scalar(2.0f);
val newShape = scalar.reshape(1, 1, 1, 1);
assertEquals(4, newShape.rank());
@ -5958,7 +5958,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test(expected = IllegalStateException.class)
public void testTranspose2() {
val scalar = Nd4j.trueScalar(2.f);
val scalar = Nd4j.scalar(2.f);
assertArrayEquals(new long[]{}, scalar.shape());
assertArrayEquals(new long[]{}, scalar.stride());
@ -5991,8 +5991,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testScalarSqueeze() {
val scalar = Nd4j.create(new float[]{2.0f}, new long[]{1, 1});
val output = Nd4j.trueScalar(0.0f);
val exp = Nd4j.trueScalar(2.0f);
val output = Nd4j.scalar(0.0f);
val exp = Nd4j.scalar(2.0f);
val op = DynamicCustomOp.builder("squeeze")
.addInputs(scalar)
.addOutputs(output)
@ -6012,8 +6012,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertArrayEquals(new long[]{1}, scalar.shape());
val output = Nd4j.trueScalar(0.0f);
val exp = Nd4j.trueScalar(2.0f);
val output = Nd4j.scalar(0.0f);
val exp = Nd4j.scalar(2.0f);
val op = DynamicCustomOp.builder("squeeze")
.addInputs(scalar)
.addOutputs(output)
@ -6113,7 +6113,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testValueArrayOf_2() {
val scalar = Nd4j.valueArrayOf(new long[] {}, 2f);
val exp = Nd4j.trueScalar(2f);
val exp = Nd4j.scalar(2f);
assertArrayEquals(exp.shape(), scalar.shape());
assertEquals(exp, scalar);
@ -6873,7 +6873,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
val exp_2 = Nd4j.create(new double[]{0.0, 1.0, 2.0}, new long[]{3});
val exp_3 = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{3});
val arrayX = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{3});
val arrayY = Nd4j.trueScalar(1.0);
val arrayY = Nd4j.scalar(1.0);
val arrayZ_1 = arrayX.add(arrayY);
assertEquals(exp_1, arrayZ_1);

View File

@ -104,7 +104,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
@Test
public void testBasicCreation_5() {
val scalar = Nd4j.trueScalar(new Integer(1));
val scalar = Nd4j.scalar(new Integer(1));
assertNotNull(scalar);
assertEquals(0, scalar.rank());
assertEquals(1, scalar.length());
@ -114,7 +114,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
@Test
public void testBasicCreation_6() {
val scalar = Nd4j.trueScalar(1);
val scalar = Nd4j.scalar(1);
assertNotNull(scalar);
assertEquals(0, scalar.rank());
assertEquals(1, scalar.length());
@ -124,7 +124,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
@Test
public void testBasicCreation_7() {
val scalar = Nd4j.trueScalar(1L);
val scalar = Nd4j.scalar(1L);
assertNotNull(scalar);
assertEquals(0, scalar.rank());
assertEquals(1, scalar.length());

View File

@ -439,7 +439,7 @@ public class ShapeTestsC extends BaseNd4jTest {
@Test
public void testReshapeToTrueScalar_1() {
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
val exp = Nd4j.trueScalar(1.0f);
val exp = Nd4j.scalar(1.0f);
assertArrayEquals(new long[]{1, 1}, orig.shape());
@ -452,7 +452,7 @@ public class ShapeTestsC extends BaseNd4jTest {
@Test
public void testReshapeToTrueScalar_2() {
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1});
val exp = Nd4j.trueScalar(1.0f);
val exp = Nd4j.scalar(1.0f);
assertArrayEquals(new long[]{1}, orig.shape());
@ -478,7 +478,7 @@ public class ShapeTestsC extends BaseNd4jTest {
@Test
public void testReshapeToTrueScalar_4() {
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
val exp = Nd4j.trueScalar(1.0f);
val exp = Nd4j.scalar(1.0f);
assertArrayEquals(new long[]{1, 1}, orig.shape());