Nd4j refactoring (#111)

* refactoring

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-13 20:44:40 +09:00 committed by Alex Black
parent c29c011d1a
commit b10ab239c0
2 changed files with 44 additions and 51 deletions

View File

@ -2554,20 +2554,18 @@ public class Nd4j {
* @param dis the data input stream to read from * @param dis the data input stream to read from
* @return the ndarray * @return the ndarray
*/ */
public static INDArray read(DataInputStream dis) throws IOException { public static INDArray read(DataInputStream dis) {
val headerShape = BaseDataBuffer.readHeader(dis); val headerShape = BaseDataBuffer.readHeader(dis);
var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle().longValue()}, headerShape.getRight()); var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle()}, headerShape.getRight());
shapeInformation.read(dis, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getThird()); shapeInformation.read(dis, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getThird());
val length = Shape.length(shapeInformation); DataType type;
DataType type = null;
DataBuffer data = null; DataBuffer data = null;
val headerData = BaseDataBuffer.readHeader(dis); val headerData = BaseDataBuffer.readHeader(dis);
try { try {
// current version contains dtype in extras // current version contains dtype in extras
data = CompressedDataBuffer.readUnknown(dis, headerData.getFirst(), headerData.getMiddle(), headerData.getRight()); data = CompressedDataBuffer.readUnknown(dis, headerData.getFirst(), headerData.getMiddle(), headerData.getRight());
type = ArrayOptionsHelper.dataType(shapeInformation.asLong());
} catch (ND4JUnknownDataTypeException e) { } catch (ND4JUnknownDataTypeException e) {
// manually setting data type // manually setting data type
type = headerData.getRight(); type = headerData.getRight();
@ -2767,23 +2765,9 @@ public class Nd4j {
} }
public static INDArray appendBias(@NonNull INDArray... vectors) { public static INDArray appendBias(@NonNull INDArray... vectors) {
INDArray ret = INSTANCE.appendBias(vectors); return INSTANCE.appendBias(vectors);
return ret;
} }
/**
* Perform an operation along a diagonal
*
* @param x the ndarray to perform the operation on
* @param func the operation to perform
*/
public static void doAlongDiagonal(INDArray x, Function<Number, Number> func) {
if (x.isMatrix())
for (int i = 0; i < x.rows(); i++)
x.put(i, i, func.apply(x.getDouble(i, i)));
}
////////////////////// RANDOM /////////////////////////////// ////////////////////// RANDOM ///////////////////////////////
/** /**
@ -2859,7 +2843,7 @@ public class Nd4j {
* @return the random ndarray with the specified shape * @return the random ndarray with the specified shape
*/ */
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) { public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) {
INDArray ret = Nd4j.createUninitialized(dataType, shape, order); //INSTANCE.rand(order, shape); INDArray ret = Nd4j.createUninitialized(dataType, shape, order);
return rand(ret); return rand(ret);
} }
@ -2874,7 +2858,7 @@ public class Nd4j {
* @return the random ndarray with the specified shape * @return the random ndarray with the specified shape
*/ */
public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) { public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) {
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); //INSTANCE.rand(order, shape); INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order());
return rand(ret); return rand(ret);
} }
@ -2889,7 +2873,7 @@ public class Nd4j {
if (rows < 1 || columns < 1) if (rows < 1 || columns < 1)
throw new ND4JIllegalStateException("Number of rows and columns should be positive for new INDArray"); throw new ND4JIllegalStateException("Number of rows and columns should be positive for new INDArray");
INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order());//INSTANCE.rand(rows, columns, Nd4j.getRandom()); INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order());
return rand(ret); return rand(ret);
} }
@ -2928,7 +2912,6 @@ public class Nd4j {
*/ */
@Deprecated @Deprecated
public static INDArray rand(int[] shape, long seed) { public static INDArray rand(int[] shape, long seed) {
INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed);
return rand(seed, ArrayUtil.toLongArray(shape)); return rand(seed, ArrayUtil.toLongArray(shape));
} }
@ -2992,8 +2975,6 @@ public class Nd4j {
* @return the random ndarray with the specified shape * @return the random ndarray with the specified shape
*/ */
public static INDArray rand(@NonNull Distribution dist, @NonNull long... shape) { public static INDArray rand(@NonNull Distribution dist, @NonNull long... shape) {
//INDArray ret = INSTANCE.rand(shape, dist);
//logCreationIfNecessary(ret);
return dist.sample(shape); return dist.sample(shape);
} }
@ -3006,7 +2987,7 @@ public class Nd4j {
* @return the random ndarray with the specified shape * @return the random ndarray with the specified shape
*/ */
public static INDArray rand(int rows, int columns, @NonNull org.nd4j.linalg.api.rng.Random rng) { public static INDArray rand(int rows, int columns, @NonNull org.nd4j.linalg.api.rng.Random rng) {
INDArray ret = createUninitialized(new int[] {rows, columns}, order());//INSTANCE.rand(rows, columns, rng); INDArray ret = createUninitialized(new int[] {rows, columns}, order());
return rand(ret, rng); return rand(ret, rng);
} }
@ -3023,7 +3004,7 @@ public class Nd4j {
*/ */
@Deprecated @Deprecated
public static INDArray rand(long[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { public static INDArray rand(long[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng); INDArray ret = createUninitialized(shape, order());
return rand(ret, min, max, rng); return rand(ret, min, max, rng);
} }
@ -3037,7 +3018,7 @@ public class Nd4j {
* @return a random matrix of the specified shape and range * @return a random matrix of the specified shape and range
*/ */
public static INDArray rand(double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) { public static INDArray rand(double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) {
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng); INDArray ret = createUninitialized(shape, order());
return rand(ret, min, max, rng); return rand(ret, min, max, rng);
} }
@ -3052,7 +3033,7 @@ public class Nd4j {
* @return a drandom matrix of the specified shape and range * @return a drandom matrix of the specified shape and range
*/ */
public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
INDArray ret = createUninitialized(rows, columns);//INSTANCE.rand(rows, columns, min, max, rng); INDArray ret = createUninitialized(rows, columns);
return rand(ret, min, max, rng); return rand(ret, min, max, rng);
} }
@ -3081,7 +3062,7 @@ public class Nd4j {
* Create a ndarray of the given shape and data type with values from N(0,1) * Create a ndarray of the given shape and data type with values from N(0,1)
* *
* @param shape the shape of the ndarray * @param shape the shape of the ndarray
* @return * @return new array with random values
*/ */
public static INDArray randn(@NonNull DataType dataType, @NonNull int... shape) { public static INDArray randn(@NonNull DataType dataType, @NonNull int... shape) {
return randn(dataType, ArrayUtil.toLongArray(shape)); return randn(dataType, ArrayUtil.toLongArray(shape));
@ -3161,7 +3142,7 @@ public class Nd4j {
* Random normal N(0, 1) using the specified seed * Random normal N(0, 1) using the specified seed
* *
* @param shape the shape of the array * @param shape the shape of the array
* @return * @return new array with random values
*/ */
public static INDArray randn(long seed, @NonNull long... shape) { public static INDArray randn(long seed, @NonNull long... shape) {
INDArray ret = Nd4j.createUninitialized(shape, order()); INDArray ret = Nd4j.createUninitialized(shape, order());
@ -3173,7 +3154,7 @@ public class Nd4j {
* *
* @param rows the number of rows in the matrix * @param rows the number of rows in the matrix
* @param columns the number of columns in the matrix * @param columns the number of columns in the matrix
* @return * @return new array with random values
*/ */
public static INDArray randn(long rows, long columns) { public static INDArray randn(long rows, long columns) {
INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
@ -3197,7 +3178,7 @@ public class Nd4j {
* *
* @param rows the number of rows in the matrix * @param rows the number of rows in the matrix
* @param columns the number of columns in the matrix * @param columns the number of columns in the matrix
* @return * @return new array with random values
*/ */
public static INDArray randn(long rows, long columns, long seed) { public static INDArray randn(long rows, long columns, long seed) {
INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
@ -3210,7 +3191,7 @@ public class Nd4j {
* @param rows the number of rows in the matrix * @param rows the number of rows in the matrix
* @param columns the number of columns in the matrix * @param columns the number of columns in the matrix
* @param r the random generator to use * @param r the random generator to use
* @return * @return new array with random values
*/ */
public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) { public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) {
INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
@ -3238,7 +3219,7 @@ public class Nd4j {
* *
* @param shape the shape of the array * @param shape the shape of the array
* @param r the random generator to use * @param r the random generator to use
* @return * @return new array with random values
*/ */
public static INDArray randn(@NonNull org.nd4j.linalg.api.rng.Random r, @NonNull long... shape) { public static INDArray randn(@NonNull org.nd4j.linalg.api.rng.Random r, @NonNull long... shape) {
final INDArray ret = Nd4j.createUninitialized(shape, order()); final INDArray ret = Nd4j.createUninitialized(shape, order());
@ -3409,9 +3390,9 @@ public class Nd4j {
* *
* PLEASE NOTE: memory of underlying array will be NOT initialized, and won't be set to 0.0 * PLEASE NOTE: memory of underlying array will be NOT initialized, and won't be set to 0.0
* *
* @param rows * @param rows rows
* @param columns * @param columns columns
* @return * @return uninitialized 2D array of rows x columns
*/ */
public static INDArray createUninitialized(long rows, long columns) { public static INDArray createUninitialized(long rows, long columns) {
return createUninitialized(new long[] {rows, columns}); return createUninitialized(new long[] {rows, columns});
@ -3532,7 +3513,7 @@ public class Nd4j {
* @return the created ndarray. * @return the created ndarray.
*/ */
public static INDArray create(double[][][] data) { public static INDArray create(double[][][] data) {
return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length}); return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length);
} }
/** /**
@ -3541,7 +3522,7 @@ public class Nd4j {
* @return the created ndarray. * @return the created ndarray.
*/ */
public static INDArray create(float[][][] data) { public static INDArray create(float[][][] data) {
return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length}); return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length);
} }
/** /**
@ -3559,7 +3540,7 @@ public class Nd4j {
* @return the created ndarray. * @return the created ndarray.
*/ */
public static INDArray create(double[][][][] data) { public static INDArray create(double[][][][] data) {
return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length, data[0][0][0].length}); return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length, data[0][0][0].length);
} }
/** /**
@ -3568,7 +3549,7 @@ public class Nd4j {
* @return the created ndarray. * @return the created ndarray.
*/ */
public static INDArray create(float[][][][] data) { public static INDArray create(float[][][][] data) {
return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length, data[0][0][0].length}); return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length, data[0][0][0].length);
} }
/** /**
@ -3609,8 +3590,7 @@ public class Nd4j {
* @return the created ndarray * @return the created ndarray
*/ */
public static INDArray create(float[] data, char order) { public static INDArray create(float[] data, char order) {
INDArray ret = INSTANCE.create(data, order); return INSTANCE.create(data, order);
return ret;
} }
/** /**
@ -3621,8 +3601,7 @@ public class Nd4j {
* @return the created ndarray * @return the created ndarray
*/ */
public static INDArray create(double[] data, char order) { public static INDArray create(double[] data, char order) {
INDArray ret = INSTANCE.create(data, order); return INSTANCE.create(data, order);
return ret;
} }
/** /**
@ -3633,8 +3612,7 @@ public class Nd4j {
* @return the created ndarray * @return the created ndarray
*/ */
public static INDArray create(int columns, char order) { public static INDArray create(int columns, char order) {
INDArray ret = INSTANCE.create(new long[] {columns}, Nd4j.getStrides(new long[] {columns}, order), 0, order); return INSTANCE.create(new long[] {columns}, Nd4j.getStrides(new long[] {columns}, order), 0, order);
return ret;
} }
/** /**
@ -3719,8 +3697,7 @@ public class Nd4j {
* @return the created ndarray. * @return the created ndarray.
*/ */
public static INDArray create(int[] data, long[] shape, long[]strides, char order, DataType type) { public static INDArray create(int[] data, long[] shape, long[]strides, char order, DataType type) {
val ret = INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace()); return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace());
return ret;
} }
/** /**

View File

@ -101,6 +101,22 @@ public class RngTests extends BaseNd4jTest {
} }
@Test
void testRandomBinomial() {
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);
assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial
x = Nd4j.randomBinomial(10, 0.5, x);
assertTrue(x.sum().getDouble(0) > 0.0);
x = Nd4j.randomExponential(0.5, 3,3);
assertTrue(x.sum().getDouble(0) > 0.0);
x = Nd4j.randomExponential(0.5, x);
assertTrue(x.sum().getDouble(0) > 0.0);
}
@Override @Override
public char ordering() { public char ordering() {