ND4J: Remove Nd4j.trueScalar/trueVector (#145)

* merge conflict.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove/replace trueVector

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-22 10:49:30 +09:00 committed by GitHub
parent 2b0d7b3b52
commit ca7e5593ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 49 additions and 124 deletions

View File

@ -17,7 +17,6 @@
package org.nd4j.linalg.factory;
import com.google.common.util.concurrent.AtomicDouble;
import lombok.val;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.blas.*;
@ -25,13 +24,13 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.random.impl.Range;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.AtomicDouble;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.*;
@ -1236,26 +1235,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
return create(shape, Nd4j.getStrides(shape), 0);
}
/**
* Create a scalar ndarray with the specified offset
*
* @param value the value to initialize the scalar with
* @param offset the offset of the ndarray
* @return the created ndarray
*/
@Override
public INDArray scalar(Number value, long offset) {
if (Nd4j.dataType() == DataType.DOUBLE)
return scalar(value.doubleValue(), offset);
if (Nd4j.dataType() == DataType.FLOAT || Nd4j.dataType() == DataType.HALF)
return scalar(value.floatValue(), offset);
if (Nd4j.dataType() == DataType.INT)
return scalar(value.intValue(), offset);
throw new IllegalStateException("Illegal data opType " + Nd4j.dataType());
}
/**
* Create a scalar nd array with the specified value and offset
*
@ -1320,7 +1299,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
return create(new int[] {value}, new long[0], new long[0], DataType.INT, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
* Create a scalar ndarray with the specified offset
*
@ -1329,30 +1307,22 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
*/
@Override
public INDArray scalar(Number value) {
if (value instanceof Double)
return create(new double[]{value.doubleValue()}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Float)
return create(new float[]{value.floatValue()}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Long)
return create(new long[]{value.longValue()}, new long[0], new long[0], DataType.LONG, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Integer)
return create(new int[]{value.intValue()}, new long[0], new long[0], DataType.INT, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Short)
return create(new short[]{value.shortValue()}, new long[0], new long[0], DataType.SHORT, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Byte)
return create(new byte[]{value.byteValue()}, new long[0], new long[0], DataType.BYTE, Nd4j.getMemoryManager().getCurrentWorkspace());
throw new IllegalStateException("Unknown instance of Number: [" + value.getClass().getCanonicalName() + "]");
}
MemoryWorkspace ws = Nd4j.getMemoryManager().getCurrentWorkspace();
/**
* Create a scalar nd array with the specified value and offset
*
* @param value the value of the scalar
* = * @return the scalar nd array
*/
@Override
public INDArray scalar(float value) {
return create(new float[] {value}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends com.google.common.util.concurrent.AtomicDouble */
return scalar(value.doubleValue());
else if (value instanceof Float)
return scalar(value.floatValue());
else if (value instanceof Long || value instanceof AtomicLong)
return create(new long[] {value.longValue()}, new long[] {}, new long[] {}, DataType.LONG, ws);
else if (value instanceof Integer || value instanceof AtomicInteger)
return create(new int[] {value.intValue()}, new long[] {}, new long[] {}, DataType.INT, ws);
else if (value instanceof Short)
return create(new short[] {value.shortValue()}, new long[] {}, new long[] {}, DataType.SHORT, ws);
else if (value instanceof Byte)
return create(new byte[] {value.byteValue()}, new long[] {}, new long[] {}, DataType.BYTE, ws);
else
throw new UnsupportedOperationException("Unsupported data type: [" + value.getClass().getSimpleName() + "]");
}
/**
@ -1366,6 +1336,11 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
return create(new double[] {value}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@Override
public INDArray scalar(float value) {
return create(new float[] {value}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@Override
public INDArray create(float[] data, int[] shape, long offset) {
return create(Nd4j.createBuffer(data), shape, offset);

View File

@ -949,16 +949,6 @@ public interface NDArrayFactory {
*/
INDArray create(int[] shape);
/**
* Create a scalar ndarray with the specified offset
*
* @param value the value to initialize the scalar with
* @param offset the offset of the ndarray
* @return the created ndarray
*/
INDArray scalar(Number value, long offset);
/**
* Create a scalar nd array with the specified value and offset
*

View File

@ -3714,46 +3714,6 @@ public class Nd4j {
return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
* @deprecated Use {@link #createFromArray(boolean...)}
*/
@Deprecated
public static INDArray trueVector(boolean[] data) {
return INSTANCE.trueVector(data);
}
/**
* @deprecated Use {@link #createFromArray(long...)}
*/
@Deprecated
public static INDArray trueVector(long[] data) {
return INSTANCE.trueVector(data);
}
/**
* @deprecated Use {@link #createFromArray(int...)}
*/
@Deprecated
public static INDArray trueVector(int[] data) {
return INSTANCE.trueVector(data);
}
/**
* @deprecated Use {@link #createFromArray(float...)}
*/
@Deprecated
public static INDArray trueVector(float[] data) {
return INSTANCE.trueVector(data);
}
/**
* @deprecated Use {@link #createFromArray(double...)}
*/
@Deprecated
public static INDArray trueVector(double[] data) {
return INSTANCE.trueVector(data);
}
/**
* This method creates "empty" INDArray with datatype determined by {@link #dataType()}
*

View File

@ -796,7 +796,7 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testFillOp(){
INDArray ia = Nd4j.trueVector(new double[]{2,2}).castTo(DataType.INT);
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
double value = 42;
INDArray out = Nd4j.create(DataType.FLOAT, 2,2);
OpTestCase op = new OpTestCase(new Fill(ia, out, value));

View File

@ -289,7 +289,7 @@ public class RandomOpValidation extends BaseOpValidation {
@Test
public void testUniformRankSimple() {
INDArray arr = Nd4j.trueVector(new double[]{100.0});
INDArray arr = Nd4j.createFromArray(new double[]{100.0});
// OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform")
// .addInputs(arr)
// .addOutputs(Nd4j.createUninitialized(new long[]{100}))
@ -321,7 +321,7 @@ public class RandomOpValidation extends BaseOpValidation {
@Test
public void testRandomExponential() {
long length = 1_000_000;
INDArray shape = Nd4j.trueVector(new double[]{length});
INDArray shape = Nd4j.createFromArray(new double[]{length});
INDArray out = Nd4j.createUninitialized(new long[]{length});
double lambda = 2;
RandomExponential op = new RandomExponential(shape, out, lambda);

View File

@ -133,7 +133,7 @@ public class ByteOrderTests extends BaseNd4jTest {
@Test
public void testVectorEncoding_1() {
val scalar = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5});
val scalar = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5});
FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0);
val fb = scalar.toFlatArray(bufferBuilder);
@ -149,7 +149,7 @@ public class ByteOrderTests extends BaseNd4jTest {
@Test
public void testVectorEncoding_2() {
val scalar = Nd4j.trueVector(new double[]{1, 2, 3, 4, 5});
val scalar = Nd4j.createFromArray(new double[]{1, 2, 3, 4, 5});
FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0);
val fb = scalar.toFlatArray(bufferBuilder);

View File

@ -132,7 +132,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
log.info(graph.asFlatPrint());
val result = graph.execAndEndResult();
val exp = Nd4j.trueVector(new long[]{2, 2, 2});
val exp = Nd4j.createFromArray(new long[]{2, 2, 2});
assertEquals(exp, result);
}

View File

@ -38,7 +38,7 @@ public class TestReverse extends BaseNd4jTest {
@Test
public void testReverse(){
INDArray in = Nd4j.trueVector(new double[]{1,2,3,4,5,6});
INDArray in = Nd4j.createFromArray(new double[]{1,2,3,4,5,6});
INDArray out = Nd4j.create(DataType.DOUBLE, 6);
DynamicCustomOp op = DynamicCustomOp.builder("reverse")
@ -55,7 +55,7 @@ public class TestReverse extends BaseNd4jTest {
@Test
public void testReverse2(){
INDArray in = Nd4j.trueVector(new double[]{1,2,3,4,5,6});
INDArray in = Nd4j.createFromArray(new double[]{1,2,3,4,5,6});
INDArray axis = Nd4j.scalar(0);
INDArray out = Nd4j.create(DataType.DOUBLE, 6);

View File

@ -5899,9 +5899,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testVector_1() {
val vector = Nd4j.trueVector(new float[] {1, 2, 3, 4, 5});
val vector2 = Nd4j.trueVector(new float[] {1, 2, 3, 4, 5});
val vector3 = Nd4j.trueVector(new float[] {1, 2, 3, 4, 6});
val vector = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 5});
val vector2 = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 5});
val vector3 = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 6});
assertFalse(vector.isScalar());
assertEquals(5, vector.length());
@ -5916,9 +5916,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testVectorScalar_2() {
val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5});
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5});
val scalar = Nd4j.scalar(2.0f);
val exp = Nd4j.trueVector(new float[]{3, 4, 5, 6, 7});
val exp = Nd4j.createFromArray(new float[]{3, 4, 5, 6, 7});
vector.addi(scalar);
@ -5937,7 +5937,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testReshapeVector() {
val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5, 6});
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
val newShape = vector.reshape(3, 2);
assertEquals(2, newShape.rank());
@ -5946,7 +5946,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test(expected = IllegalStateException.class)
public void testTranspose1() {
val vector = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5, 6});
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
assertArrayEquals(new long[]{6}, vector.shape());
assertArrayEquals(new long[]{1}, vector.stride());
@ -6030,8 +6030,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testVectorSqueeze() {
val vector = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6}, new long[]{1, 6});
val output = Nd4j.trueVector(new float[] {0, 0, 0, 0, 0, 0});
val exp = Nd4j.trueVector(new float[]{1, 2, 3, 4, 5, 6});
val output = Nd4j.createFromArray(new float[] {0, 0, 0, 0, 0, 0});
val exp = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
val op = DynamicCustomOp.builder("squeeze")
.addInputs(vector)
@ -6078,11 +6078,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testVectorScalarConcat() {
val vector = Nd4j.trueVector(new float[] {1, 2});
val vector = Nd4j.createFromArray(new float[] {1, 2});
val scalar = Nd4j.scalar(3.0f);
val output = Nd4j.trueVector(new float[]{0, 0, 0});
val exp = Nd4j.trueVector(new float[]{1, 2, 3});
val output = Nd4j.createFromArray(new float[]{0, 0, 0});
val exp = Nd4j.createFromArray(new float[]{1, 2, 3});
val op = DynamicCustomOp.builder("concat")
.addInputs(vector, scalar)
@ -6103,7 +6103,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testValueArrayOf_1() {
val vector = Nd4j.valueArrayOf(new long[] {5}, 2f, DataType.FLOAT);
val exp = Nd4j.trueVector(new float[]{2, 2, 2, 2, 2});
val exp = Nd4j.createFromArray(new float[]{2, 2, 2, 2, 2});
assertArrayEquals(exp.shape(), vector.shape());
assertEquals(exp, vector);
@ -6123,7 +6123,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testArrayCreation() {
val vector = Nd4j.create(new float[]{1, 2, 3}, new long[] {3}, 'c');
val exp = Nd4j.trueVector(new float[]{1, 2, 3});
val exp = Nd4j.createFromArray(new float[]{1, 2, 3});
assertArrayEquals(exp.shape(), vector.shape());
assertEquals(exp, vector);
@ -6964,8 +6964,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray arr = Nd4j.create(new boolean[][]{{false,true,false},{false,false,true},{false,false,true}});
INDArray[] exp = new INDArray[]{
Nd4j.trueVector(new long[]{0,1,2}),
Nd4j.trueVector(new long[]{1,2,2})};
Nd4j.createFromArray(new long[]{0,1,2}),
Nd4j.createFromArray(new long[]{1,2,2})};
INDArray[] act = Nd4j.where(arr, null, null);
@ -6980,9 +6980,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
arr.putScalar(1,2,1,1.0);
arr.putScalar(2,2,1,1.0);
INDArray[] exp = new INDArray[]{
Nd4j.trueVector(new long[]{0,1,2}),
Nd4j.trueVector(new long[]{1,2,2}),
Nd4j.trueVector(new long[]{0,1,1})
Nd4j.createFromArray(new long[]{0,1,2}),
Nd4j.createFromArray(new long[]{1,2,2}),
Nd4j.createFromArray(new long[]{0,1,1})
};
INDArray[] act = Nd4j.where(arr, null, null);

View File

@ -465,7 +465,7 @@ public class ShapeTestsC extends BaseNd4jTest {
@Test
public void testReshapeToTrueScalar_3() {
val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
val exp = Nd4j.trueVector(new float[]{1.0f});
val exp = Nd4j.createFromArray(new float[]{1.0f});
assertArrayEquals(new long[]{1, 1}, orig.shape());