diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index e316ded84..2cb7c9eeb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -17,7 +17,6 @@ package org.nd4j.linalg.factory; -import com.google.common.util.concurrent.AtomicDouble; import lombok.val; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.blas.*; @@ -25,13 +24,13 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.random.impl.Range; import org.nd4j.linalg.api.rng.distribution.Distribution; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.primitives.AtomicDouble; import org.nd4j.linalg.util.ArrayUtil; import java.util.*; @@ -1236,26 +1235,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { return create(shape, Nd4j.getStrides(shape), 0); } - - /** - * Create a scalar ndarray with the specified offset - * - * @param value the value to initialize the scalar with - * @param offset the offset of the ndarray - * @return the created ndarray - */ - @Override - public INDArray scalar(Number value, long offset) { - if (Nd4j.dataType() == DataType.DOUBLE) - return scalar(value.doubleValue(), offset); - if (Nd4j.dataType() == DataType.FLOAT || Nd4j.dataType() == DataType.HALF) - return scalar(value.floatValue(), offset); - if (Nd4j.dataType() == DataType.INT) - return scalar(value.intValue(), offset); - throw new IllegalStateException("Illegal data opType " + Nd4j.dataType()); - } - - /** * Create a scalar nd array with the specified value and offset * @@ -1320,7 +1299,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { return create(new int[] {value}, new long[0], new long[0], DataType.INT, Nd4j.getMemoryManager().getCurrentWorkspace()); } - /** * Create a scalar ndarray with the specified offset * @@ -1329,30 +1307,22 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { */ @Override public INDArray scalar(Number value) { - if (value instanceof Double) - return create(new double[]{value.doubleValue()}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); - else if (value instanceof Float) - return create(new float[]{value.floatValue()}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); - else if (value instanceof Long) - return create(new long[]{value.longValue()}, new long[0], new long[0], DataType.LONG, Nd4j.getMemoryManager().getCurrentWorkspace()); - else if (value instanceof Integer) - return create(new int[]{value.intValue()}, new long[0], new long[0], DataType.INT, Nd4j.getMemoryManager().getCurrentWorkspace()); - else if (value instanceof Short) - return create(new short[]{value.shortValue()}, new long[0], new long[0], DataType.SHORT, Nd4j.getMemoryManager().getCurrentWorkspace()); - else if (value instanceof Byte) - return create(new byte[]{value.byteValue()}, new long[0], new long[0], DataType.BYTE, Nd4j.getMemoryManager().getCurrentWorkspace()); - throw new IllegalStateException("Unknown instance of Number: [" + value.getClass().getCanonicalName() + "]"); - } + MemoryWorkspace ws = Nd4j.getMemoryManager().getCurrentWorkspace(); - /** - * Create a scalar nd array with the specified value and offset - * - * @param value the value of the scalar - * = * @return the scalar nd array - */ - @Override - public INDArray scalar(float value) { - return create(new float[] {value}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); + if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends com.google.common.util.concurrent.AtomicDouble */ + return scalar(value.doubleValue()); + else if (value instanceof Float) + return scalar(value.floatValue()); + else if (value instanceof Long || value instanceof AtomicLong) + return create(new long[] {value.longValue()}, new long[] {}, new long[] {}, DataType.LONG, ws); + else if (value instanceof Integer || value instanceof AtomicInteger) + return create(new int[] {value.intValue()}, new long[] {}, new long[] {}, DataType.INT, ws); + else if (value instanceof Short) + return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, DataType.SHORT, ws); + else if (value instanceof Byte) + return create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, DataType.BYTE, ws); + else + throw new UnsupportedOperationException("Unsupported data type: [" + value.getClass().getSimpleName() + "]"); } /** @@ -1366,6 +1336,11 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { return create(new double[] {value}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } + @Override + public INDArray scalar(float value) { + return create(new float[] {value}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); + } + @Override public INDArray create(float[] data, int[] shape, long offset) { return create(Nd4j.createBuffer(data), shape, offset); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java index eaee86617..0cd52b1d2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java @@ -949,16 +949,6 @@ public interface NDArrayFactory { */ INDArray create(int[] shape); - /** - * Create a scalar ndarray with the specified offset - * - * @param value the value to initialize the scalar with - * @param offset the offset of the ndarray - * @return the created ndarray - */ - INDArray scalar(Number value, long offset); - - /** * Create a scalar nd array with the specified value and offset * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 683acc0d6..eb3250e3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -3714,46 +3714,6 @@ public class Nd4j { return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace()); } - /** - * @deprecated Use {@link #createFromArray(boolean...)} - */ - @Deprecated - public static INDArray trueVector(boolean[] data) { - return INSTANCE.trueVector(data); - } - - /** - * @deprecated Use {@link #createFromArray(long...)} - */ - @Deprecated - public static INDArray trueVector(long[] data) { - return INSTANCE.trueVector(data); - } - - /** - * @deprecated Use {@link #createFromArray(int...)} - */ - @Deprecated - public static INDArray trueVector(int[] data) { - return INSTANCE.trueVector(data); - } - - /** - * @deprecated Use {@link #createFromArray(float...)} - */ - @Deprecated - public static INDArray trueVector(float[] data) { - return INSTANCE.trueVector(data); - } - - /** - * @deprecated Use {@link #createFromArray(double...)} - */ - @Deprecated - public static INDArray trueVector(double[] data) { - return INSTANCE.trueVector(data); - } - /** * This method creates "empty" INDArray with datatype determined by {@link #dataType()} * diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 57f072aac..2a4b032b5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -796,7 +796,7 @@ public class MiscOpValidation extends BaseOpValidation { @Test public void testFillOp(){ - INDArray ia = Nd4j.trueVector(new double[]{2,2}).castTo(DataType.INT); + INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT); double value = 42; INDArray out = Nd4j.create(DataType.FLOAT, 2,2); OpTestCase op = new OpTestCase(new Fill(ia, out, value)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 9b6b6bbef..646cae454 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -289,7 +289,7 @@ public class RandomOpValidation extends BaseOpValidation { @Test public void testUniformRankSimple() { - INDArray arr = Nd4j.trueVector(new double[]{100.0}); + INDArray arr = Nd4j.createFromArray(new double[]{100.0}); // OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform") // .addInputs(arr) // .addOutputs(Nd4j.createUninitialized(new long[]{100})) @@ -321,7 +321,7 @@ public class RandomOpValidation extends BaseOpValidation { @Test public void testRandomExponential() { long length = 1_000_000; - INDArray shape = Nd4j.trueVector(new double[]{length}); + INDArray shape = Nd4j.createFromArray(new double[]{length}); INDArray out = Nd4j.createUninitialized(new long[]{length}); double lambda = 2; RandomExponential op = new RandomExponential(shape, out, lambda); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java index 8eff8e532..c2ca2148e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java @@ -133,7 +133,7 @@ public class ByteOrderTests extends BaseNd4jTest { @Test public void testVectorEncoding_1() { - val scalar = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5}); + val scalar = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0); val fb = scalar.toFlatArray(bufferBuilder); @@ -149,7 +149,7 @@ public class ByteOrderTests extends BaseNd4jTest { @Test public void testVectorEncoding_2() { - val scalar = Nd4j.trueVector(new double[]{1, 2, 3, 4, 5}); + val scalar = Nd4j.createFromArray(new double[]{1, 2, 3, 4, 5}); FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0); val fb = scalar.toFlatArray(bufferBuilder); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index c6878d1be..a501f9ff4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -132,7 +132,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { log.info(graph.asFlatPrint()); val result = graph.execAndEndResult(); - val exp = Nd4j.trueVector(new long[]{2, 2, 2}); + val exp = Nd4j.createFromArray(new long[]{2, 2, 2}); assertEquals(exp, result); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java index c5d96746f..b54c7985e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java @@ -38,7 +38,7 @@ public class TestReverse extends BaseNd4jTest { @Test public void testReverse(){ - INDArray in = Nd4j.trueVector(new double[]{1,2,3,4,5,6}); + INDArray in = Nd4j.createFromArray(new double[]{1,2,3,4,5,6}); INDArray out = Nd4j.create(DataType.DOUBLE, 6); DynamicCustomOp op = DynamicCustomOp.builder("reverse") @@ -55,7 +55,7 @@ public class TestReverse extends BaseNd4jTest { @Test public void testReverse2(){ - INDArray in = Nd4j.trueVector(new double[]{1,2,3,4,5,6}); + INDArray in = Nd4j.createFromArray(new double[]{1,2,3,4,5,6}); INDArray axis = Nd4j.scalar(0); INDArray out = Nd4j.create(DataType.DOUBLE, 6); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index b948220fd..b302c8c0f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -5899,9 +5899,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testVector_1() { - val vector = Nd4j.trueVector(new float[] {1, 2, 3, 4, 5}); - val vector2 = Nd4j.trueVector(new float[] {1, 2, 3, 4, 5}); - val vector3 = Nd4j.trueVector(new float[] {1, 2, 3, 4, 6}); + val vector = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 5}); + val vector2 = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 5}); + val vector3 = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 6}); assertFalse(vector.isScalar()); assertEquals(5, vector.length()); @@ -5916,9 +5916,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testVectorScalar_2() { - val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5}); + val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); val scalar = Nd4j.scalar(2.0f); - val exp = Nd4j.trueVector(new float[]{3, 4, 5, 6, 7}); + val exp = Nd4j.createFromArray(new float[]{3, 4, 5, 6, 7}); vector.addi(scalar); @@ -5937,7 +5937,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testReshapeVector() { - val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5, 6}); + val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6}); val newShape = vector.reshape(3, 2); assertEquals(2, newShape.rank()); @@ -5946,7 +5946,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test(expected = IllegalStateException.class) public void testTranspose1() { - val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5, 6}); + val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6}); assertArrayEquals(new long[]{6}, vector.shape()); assertArrayEquals(new long[]{1}, vector.stride()); @@ -6030,8 +6030,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testVectorSqueeze() { val vector = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6}, new long[]{1, 6}); - val output = Nd4j.trueVector(new float[] {0, 0, 0, 0, 0, 0}); - val exp = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5, 6}); + val output = Nd4j.createFromArray(new float[] {0, 0, 0, 0, 0, 0}); + val exp = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6}); val op = DynamicCustomOp.builder("squeeze") .addInputs(vector) @@ -6078,11 +6078,11 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testVectorScalarConcat() { - val vector = Nd4j.trueVector(new float[] {1, 2}); + val vector = Nd4j.createFromArray(new float[] {1, 2}); val scalar = Nd4j.scalar(3.0f); - val output = Nd4j.trueVector(new float[]{0, 0, 0}); - val exp = Nd4j.trueVector(new float[]{1, 2, 3}); + val output = Nd4j.createFromArray(new float[]{0, 0, 0}); + val exp = Nd4j.createFromArray(new float[]{1, 2, 3}); val op = DynamicCustomOp.builder("concat") .addInputs(vector, scalar) @@ -6103,7 +6103,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testValueArrayOf_1() { val vector = Nd4j.valueArrayOf(new long[] {5}, 2f, DataType.FLOAT); - val exp = Nd4j.trueVector(new float[]{2, 2, 2, 2, 2}); + val exp = Nd4j.createFromArray(new float[]{2, 2, 2, 2, 2}); assertArrayEquals(exp.shape(), vector.shape()); assertEquals(exp, vector); @@ -6123,7 +6123,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testArrayCreation() { val vector = Nd4j.create(new float[]{1, 2, 3}, new long[] {3}, 'c'); - val exp = Nd4j.trueVector(new float[]{1, 2, 3}); + val exp = Nd4j.createFromArray(new float[]{1, 2, 3}); assertArrayEquals(exp.shape(), vector.shape()); assertEquals(exp, vector); @@ -6964,8 +6964,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr = Nd4j.create(new boolean[][]{{false,true,false},{false,false,true},{false,false,true}}); INDArray[] exp = new INDArray[]{ - Nd4j.trueVector(new long[]{0,1,2}), - Nd4j.trueVector(new long[]{1,2,2})}; + Nd4j.createFromArray(new long[]{0,1,2}), + Nd4j.createFromArray(new long[]{1,2,2})}; INDArray[] act = Nd4j.where(arr, null, null); @@ -6980,9 +6980,9 @@ public class Nd4jTestsC extends BaseNd4jTest { arr.putScalar(1,2,1,1.0); arr.putScalar(2,2,1,1.0); INDArray[] exp = new INDArray[]{ - Nd4j.trueVector(new long[]{0,1,2}), - Nd4j.trueVector(new long[]{1,2,2}), - Nd4j.trueVector(new long[]{0,1,1}) + Nd4j.createFromArray(new long[]{0,1,2}), + Nd4j.createFromArray(new long[]{1,2,2}), + Nd4j.createFromArray(new long[]{0,1,1}) }; INDArray[] act = Nd4j.where(arr, null, null); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java index f6d0fb3d3..f4f3e67f2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java @@ -465,7 +465,7 @@ public class ShapeTestsC extends BaseNd4jTest { @Test public void testReshapeToTrueScalar_3() { val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1}); - val exp = Nd4j.trueVector(new float[]{1.0f}); + val exp = Nd4j.createFromArray(new float[]{1.0f}); assertArrayEquals(new long[]{1, 1}, orig.shape());