ND4J: Remove Nd4j.trueScalar/trueVector (#145)
* merge conflict. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove/replace trueVector Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
2b0d7b3b52
commit
ca7e5593ec
|
@ -17,7 +17,6 @@
|
|||
package org.nd4j.linalg.factory;
|
||||
|
||||
|
||||
import com.google.common.util.concurrent.AtomicDouble;
|
||||
import lombok.val;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.blas.*;
|
||||
|
@ -25,13 +24,13 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOp;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.random.impl.Range;
|
||||
import org.nd4j.linalg.api.rng.distribution.Distribution;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.primitives.AtomicDouble;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.*;
|
||||
|
@ -1236,26 +1235,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
|||
return create(shape, Nd4j.getStrides(shape), 0);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Create a scalar ndarray with the specified offset
|
||||
*
|
||||
* @param value the value to initialize the scalar with
|
||||
* @param offset the offset of the ndarray
|
||||
* @return the created ndarray
|
||||
*/
|
||||
@Override
|
||||
public INDArray scalar(Number value, long offset) {
|
||||
if (Nd4j.dataType() == DataType.DOUBLE)
|
||||
return scalar(value.doubleValue(), offset);
|
||||
if (Nd4j.dataType() == DataType.FLOAT || Nd4j.dataType() == DataType.HALF)
|
||||
return scalar(value.floatValue(), offset);
|
||||
if (Nd4j.dataType() == DataType.INT)
|
||||
return scalar(value.intValue(), offset);
|
||||
throw new IllegalStateException("Illegal data opType " + Nd4j.dataType());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Create a scalar nd array with the specified value and offset
|
||||
*
|
||||
|
@ -1320,7 +1299,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
|||
return create(new int[] {value}, new long[0], new long[0], DataType.INT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Create a scalar ndarray with the specified offset
|
||||
*
|
||||
|
@ -1329,30 +1307,22 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
|||
*/
|
||||
@Override
|
||||
public INDArray scalar(Number value) {
|
||||
if (value instanceof Double)
|
||||
return create(new double[]{value.doubleValue()}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Float)
|
||||
return create(new float[]{value.floatValue()}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Long)
|
||||
return create(new long[]{value.longValue()}, new long[0], new long[0], DataType.LONG, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Integer)
|
||||
return create(new int[]{value.intValue()}, new long[0], new long[0], DataType.INT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Short)
|
||||
return create(new short[]{value.shortValue()}, new long[0], new long[0], DataType.SHORT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Byte)
|
||||
return create(new byte[]{value.byteValue()}, new long[0], new long[0], DataType.BYTE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
throw new IllegalStateException("Unknown instance of Number: [" + value.getClass().getCanonicalName() + "]");
|
||||
}
|
||||
MemoryWorkspace ws = Nd4j.getMemoryManager().getCurrentWorkspace();
|
||||
|
||||
/**
|
||||
* Create a scalar nd array with the specified value and offset
|
||||
*
|
||||
* @param value the value of the scalar
|
||||
* = * @return the scalar nd array
|
||||
*/
|
||||
@Override
|
||||
public INDArray scalar(float value) {
|
||||
return create(new float[] {value}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends com.google.common.util.concurrent.AtomicDouble */
|
||||
return scalar(value.doubleValue());
|
||||
else if (value instanceof Float)
|
||||
return scalar(value.floatValue());
|
||||
else if (value instanceof Long || value instanceof AtomicLong)
|
||||
return create(new long[] {value.longValue()}, new long[] {}, new long[] {}, DataType.LONG, ws);
|
||||
else if (value instanceof Integer || value instanceof AtomicInteger)
|
||||
return create(new int[] {value.intValue()}, new long[] {}, new long[] {}, DataType.INT, ws);
|
||||
else if (value instanceof Short)
|
||||
return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, DataType.SHORT, ws);
|
||||
else if (value instanceof Byte)
|
||||
return create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, DataType.BYTE, ws);
|
||||
else
|
||||
throw new UnsupportedOperationException("Unsupported data type: [" + value.getClass().getSimpleName() + "]");
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1366,6 +1336,11 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
|||
return create(new double[] {value}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray scalar(float value) {
|
||||
return create(new float[] {value}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray create(float[] data, int[] shape, long offset) {
|
||||
return create(Nd4j.createBuffer(data), shape, offset);
|
||||
|
|
|
@ -949,16 +949,6 @@ public interface NDArrayFactory {
|
|||
*/
|
||||
INDArray create(int[] shape);
|
||||
|
||||
/**
|
||||
* Create a scalar ndarray with the specified offset
|
||||
*
|
||||
* @param value the value to initialize the scalar with
|
||||
* @param offset the offset of the ndarray
|
||||
* @return the created ndarray
|
||||
*/
|
||||
INDArray scalar(Number value, long offset);
|
||||
|
||||
|
||||
/**
|
||||
* Create a scalar nd array with the specified value and offset
|
||||
*
|
||||
|
|
|
@ -3714,46 +3714,6 @@ public class Nd4j {
|
|||
return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #createFromArray(boolean...)}
|
||||
*/
|
||||
@Deprecated
|
||||
public static INDArray trueVector(boolean[] data) {
|
||||
return INSTANCE.trueVector(data);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #createFromArray(long...)}
|
||||
*/
|
||||
@Deprecated
|
||||
public static INDArray trueVector(long[] data) {
|
||||
return INSTANCE.trueVector(data);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #createFromArray(int...)}
|
||||
*/
|
||||
@Deprecated
|
||||
public static INDArray trueVector(int[] data) {
|
||||
return INSTANCE.trueVector(data);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #createFromArray(float...)}
|
||||
*/
|
||||
@Deprecated
|
||||
public static INDArray trueVector(float[] data) {
|
||||
return INSTANCE.trueVector(data);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #createFromArray(double...)}
|
||||
*/
|
||||
@Deprecated
|
||||
public static INDArray trueVector(double[] data) {
|
||||
return INSTANCE.trueVector(data);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method creates "empty" INDArray with datatype determined by {@link #dataType()}
|
||||
*
|
||||
|
|
|
@ -796,7 +796,7 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
@Test
|
||||
public void testFillOp(){
|
||||
|
||||
INDArray ia = Nd4j.trueVector(new double[]{2,2}).castTo(DataType.INT);
|
||||
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
|
||||
double value = 42;
|
||||
INDArray out = Nd4j.create(DataType.FLOAT, 2,2);
|
||||
OpTestCase op = new OpTestCase(new Fill(ia, out, value));
|
||||
|
|
|
@ -289,7 +289,7 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
@Test
|
||||
public void testUniformRankSimple() {
|
||||
|
||||
INDArray arr = Nd4j.trueVector(new double[]{100.0});
|
||||
INDArray arr = Nd4j.createFromArray(new double[]{100.0});
|
||||
// OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform")
|
||||
// .addInputs(arr)
|
||||
// .addOutputs(Nd4j.createUninitialized(new long[]{100}))
|
||||
|
@ -321,7 +321,7 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
@Test
|
||||
public void testRandomExponential() {
|
||||
long length = 1_000_000;
|
||||
INDArray shape = Nd4j.trueVector(new double[]{length});
|
||||
INDArray shape = Nd4j.createFromArray(new double[]{length});
|
||||
INDArray out = Nd4j.createUninitialized(new long[]{length});
|
||||
double lambda = 2;
|
||||
RandomExponential op = new RandomExponential(shape, out, lambda);
|
||||
|
|
|
@ -133,7 +133,7 @@ public class ByteOrderTests extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void testVectorEncoding_1() {
|
||||
val scalar = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5});
|
||||
val scalar = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5});
|
||||
|
||||
FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0);
|
||||
val fb = scalar.toFlatArray(bufferBuilder);
|
||||
|
@ -149,7 +149,7 @@ public class ByteOrderTests extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void testVectorEncoding_2() {
|
||||
val scalar = Nd4j.trueVector(new double[]{1, 2, 3, 4, 5});
|
||||
val scalar = Nd4j.createFromArray(new double[]{1, 2, 3, 4, 5});
|
||||
|
||||
FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0);
|
||||
val fb = scalar.toFlatArray(bufferBuilder);
|
||||
|
|
|
@ -132,7 +132,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
|
|||
log.info(graph.asFlatPrint());
|
||||
val result = graph.execAndEndResult();
|
||||
|
||||
val exp = Nd4j.trueVector(new long[]{2, 2, 2});
|
||||
val exp = Nd4j.createFromArray(new long[]{2, 2, 2});
|
||||
|
||||
assertEquals(exp, result);
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ public class TestReverse extends BaseNd4jTest {
|
|||
@Test
|
||||
public void testReverse(){
|
||||
|
||||
INDArray in = Nd4j.trueVector(new double[]{1,2,3,4,5,6});
|
||||
INDArray in = Nd4j.createFromArray(new double[]{1,2,3,4,5,6});
|
||||
INDArray out = Nd4j.create(DataType.DOUBLE, 6);
|
||||
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("reverse")
|
||||
|
@ -55,7 +55,7 @@ public class TestReverse extends BaseNd4jTest {
|
|||
@Test
|
||||
public void testReverse2(){
|
||||
|
||||
INDArray in = Nd4j.trueVector(new double[]{1,2,3,4,5,6});
|
||||
INDArray in = Nd4j.createFromArray(new double[]{1,2,3,4,5,6});
|
||||
INDArray axis = Nd4j.scalar(0);
|
||||
INDArray out = Nd4j.create(DataType.DOUBLE, 6);
|
||||
|
||||
|
|
|
@ -5899,9 +5899,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void testVector_1() {
|
||||
val vector = Nd4j.trueVector(new float[] {1, 2, 3, 4, 5});
|
||||
val vector2 = Nd4j.trueVector(new float[] {1, 2, 3, 4, 5});
|
||||
val vector3 = Nd4j.trueVector(new float[] {1, 2, 3, 4, 6});
|
||||
val vector = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 5});
|
||||
val vector2 = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 5});
|
||||
val vector3 = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 6});
|
||||
|
||||
assertFalse(vector.isScalar());
|
||||
assertEquals(5, vector.length());
|
||||
|
@ -5916,9 +5916,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void testVectorScalar_2() {
|
||||
val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5});
|
||||
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5});
|
||||
val scalar = Nd4j.scalar(2.0f);
|
||||
val exp = Nd4j.trueVector(new float[]{3, 4, 5, 6, 7});
|
||||
val exp = Nd4j.createFromArray(new float[]{3, 4, 5, 6, 7});
|
||||
|
||||
vector.addi(scalar);
|
||||
|
||||
|
@ -5937,7 +5937,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void testReshapeVector() {
|
||||
val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5, 6});
|
||||
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
|
||||
val newShape = vector.reshape(3, 2);
|
||||
|
||||
assertEquals(2, newShape.rank());
|
||||
|
@ -5946,7 +5946,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void testTranspose1() {
|
||||
val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5, 6});
|
||||
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
|
||||
|
||||
assertArrayEquals(new long[]{6}, vector.shape());
|
||||
assertArrayEquals(new long[]{1}, vector.stride());
|
||||
|
@ -6030,8 +6030,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
@Test
|
||||
public void testVectorSqueeze() {
|
||||
val vector = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6}, new long[]{1, 6});
|
||||
val output = Nd4j.trueVector(new float[] {0, 0, 0, 0, 0, 0});
|
||||
val exp = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5, 6});
|
||||
val output = Nd4j.createFromArray(new float[] {0, 0, 0, 0, 0, 0});
|
||||
val exp = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
|
||||
|
||||
val op = DynamicCustomOp.builder("squeeze")
|
||||
.addInputs(vector)
|
||||
|
@ -6078,11 +6078,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void testVectorScalarConcat() {
|
||||
val vector = Nd4j.trueVector(new float[] {1, 2});
|
||||
val vector = Nd4j.createFromArray(new float[] {1, 2});
|
||||
val scalar = Nd4j.scalar(3.0f);
|
||||
|
||||
val output = Nd4j.trueVector(new float[]{0, 0, 0});
|
||||
val exp = Nd4j.trueVector(new float[]{1, 2, 3});
|
||||
val output = Nd4j.createFromArray(new float[]{0, 0, 0});
|
||||
val exp = Nd4j.createFromArray(new float[]{1, 2, 3});
|
||||
|
||||
val op = DynamicCustomOp.builder("concat")
|
||||
.addInputs(vector, scalar)
|
||||
|
@ -6103,7 +6103,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
@Test
|
||||
public void testValueArrayOf_1() {
|
||||
val vector = Nd4j.valueArrayOf(new long[] {5}, 2f, DataType.FLOAT);
|
||||
val exp = Nd4j.trueVector(new float[]{2, 2, 2, 2, 2});
|
||||
val exp = Nd4j.createFromArray(new float[]{2, 2, 2, 2, 2});
|
||||
|
||||
assertArrayEquals(exp.shape(), vector.shape());
|
||||
assertEquals(exp, vector);
|
||||
|
@ -6123,7 +6123,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
@Test
|
||||
public void testArrayCreation() {
|
||||
val vector = Nd4j.create(new float[]{1, 2, 3}, new long[] {3}, 'c');
|
||||
val exp = Nd4j.trueVector(new float[]{1, 2, 3});
|
||||
val exp = Nd4j.createFromArray(new float[]{1, 2, 3});
|
||||
|
||||
assertArrayEquals(exp.shape(), vector.shape());
|
||||
assertEquals(exp, vector);
|
||||
|
@ -6964,8 +6964,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
INDArray arr = Nd4j.create(new boolean[][]{{false,true,false},{false,false,true},{false,false,true}});
|
||||
INDArray[] exp = new INDArray[]{
|
||||
Nd4j.trueVector(new long[]{0,1,2}),
|
||||
Nd4j.trueVector(new long[]{1,2,2})};
|
||||
Nd4j.createFromArray(new long[]{0,1,2}),
|
||||
Nd4j.createFromArray(new long[]{1,2,2})};
|
||||
|
||||
INDArray[] act = Nd4j.where(arr, null, null);
|
||||
|
||||
|
@ -6980,9 +6980,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
arr.putScalar(1,2,1,1.0);
|
||||
arr.putScalar(2,2,1,1.0);
|
||||
INDArray[] exp = new INDArray[]{
|
||||
Nd4j.trueVector(new long[]{0,1,2}),
|
||||
Nd4j.trueVector(new long[]{1,2,2}),
|
||||
Nd4j.trueVector(new long[]{0,1,1})
|
||||
Nd4j.createFromArray(new long[]{0,1,2}),
|
||||
Nd4j.createFromArray(new long[]{1,2,2}),
|
||||
Nd4j.createFromArray(new long[]{0,1,1})
|
||||
};
|
||||
|
||||
INDArray[] act = Nd4j.where(arr, null, null);
|
||||
|
|
|
@ -465,7 +465,7 @@ public class ShapeTestsC extends BaseNd4jTest {
|
|||
@Test
|
||||
public void testReshapeToTrueScalar_3() {
|
||||
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
|
||||
val exp = Nd4j.trueVector(new float[]{1.0f});
|
||||
val exp = Nd4j.createFromArray(new float[]{1.0f});
|
||||
|
||||
assertArrayEquals(new long[]{1, 1}, orig.shape());
|
||||
|
||||
|
|
Loading…
Reference in New Issue