Nd4j refactoring (#111)
* refactoring Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
c29c011d1a
commit
b10ab239c0
|
@ -2554,20 +2554,18 @@ public class Nd4j {
|
|||
* @param dis the data input stream to read from
|
||||
* @return the ndarray
|
||||
*/
|
||||
public static INDArray read(DataInputStream dis) throws IOException {
|
||||
public static INDArray read(DataInputStream dis) {
|
||||
val headerShape = BaseDataBuffer.readHeader(dis);
|
||||
|
||||
var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle().longValue()}, headerShape.getRight());
|
||||
var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle()}, headerShape.getRight());
|
||||
shapeInformation.read(dis, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getThird());
|
||||
val length = Shape.length(shapeInformation);
|
||||
DataType type = null;
|
||||
DataType type;
|
||||
DataBuffer data = null;
|
||||
|
||||
val headerData = BaseDataBuffer.readHeader(dis);
|
||||
try {
|
||||
// current version contains dtype in extras
|
||||
data = CompressedDataBuffer.readUnknown(dis, headerData.getFirst(), headerData.getMiddle(), headerData.getRight());
|
||||
type = ArrayOptionsHelper.dataType(shapeInformation.asLong());
|
||||
} catch (ND4JUnknownDataTypeException e) {
|
||||
// manually setting data type
|
||||
type = headerData.getRight();
|
||||
|
@ -2767,23 +2765,9 @@ public class Nd4j {
|
|||
}
|
||||
|
||||
public static INDArray appendBias(@NonNull INDArray... vectors) {
|
||||
INDArray ret = INSTANCE.appendBias(vectors);
|
||||
return ret;
|
||||
return INSTANCE.appendBias(vectors);
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform an operation along a diagonal
|
||||
*
|
||||
* @param x the ndarray to perform the operation on
|
||||
* @param func the operation to perform
|
||||
*/
|
||||
public static void doAlongDiagonal(INDArray x, Function<Number, Number> func) {
|
||||
if (x.isMatrix())
|
||||
for (int i = 0; i < x.rows(); i++)
|
||||
x.put(i, i, func.apply(x.getDouble(i, i)));
|
||||
}
|
||||
|
||||
|
||||
////////////////////// RANDOM ///////////////////////////////
|
||||
|
||||
/**
|
||||
|
@ -2859,7 +2843,7 @@ public class Nd4j {
|
|||
* @return the random ndarray with the specified shape
|
||||
*/
|
||||
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) {
|
||||
INDArray ret = Nd4j.createUninitialized(dataType, shape, order); //INSTANCE.rand(order, shape);
|
||||
INDArray ret = Nd4j.createUninitialized(dataType, shape, order);
|
||||
return rand(ret);
|
||||
}
|
||||
|
||||
|
@ -2874,7 +2858,7 @@ public class Nd4j {
|
|||
* @return the random ndarray with the specified shape
|
||||
*/
|
||||
public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) {
|
||||
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); //INSTANCE.rand(order, shape);
|
||||
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order());
|
||||
return rand(ret);
|
||||
}
|
||||
|
||||
|
@ -2889,7 +2873,7 @@ public class Nd4j {
|
|||
if (rows < 1 || columns < 1)
|
||||
throw new ND4JIllegalStateException("Number of rows and columns should be positive for new INDArray");
|
||||
|
||||
INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order());//INSTANCE.rand(rows, columns, Nd4j.getRandom());
|
||||
INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order());
|
||||
return rand(ret);
|
||||
}
|
||||
|
||||
|
@ -2928,7 +2912,6 @@ public class Nd4j {
|
|||
*/
|
||||
@Deprecated
|
||||
public static INDArray rand(int[] shape, long seed) {
|
||||
INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed);
|
||||
return rand(seed, ArrayUtil.toLongArray(shape));
|
||||
}
|
||||
|
||||
|
@ -2992,8 +2975,6 @@ public class Nd4j {
|
|||
* @return the random ndarray with the specified shape
|
||||
*/
|
||||
public static INDArray rand(@NonNull Distribution dist, @NonNull long... shape) {
|
||||
//INDArray ret = INSTANCE.rand(shape, dist);
|
||||
//logCreationIfNecessary(ret);
|
||||
return dist.sample(shape);
|
||||
}
|
||||
|
||||
|
@ -3006,7 +2987,7 @@ public class Nd4j {
|
|||
* @return the random ndarray with the specified shape
|
||||
*/
|
||||
public static INDArray rand(int rows, int columns, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||
INDArray ret = createUninitialized(new int[] {rows, columns}, order());//INSTANCE.rand(rows, columns, rng);
|
||||
INDArray ret = createUninitialized(new int[] {rows, columns}, order());
|
||||
return rand(ret, rng);
|
||||
}
|
||||
|
||||
|
@ -3023,7 +3004,7 @@ public class Nd4j {
|
|||
*/
|
||||
@Deprecated
|
||||
public static INDArray rand(long[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng);
|
||||
INDArray ret = createUninitialized(shape, order());
|
||||
return rand(ret, min, max, rng);
|
||||
}
|
||||
|
||||
|
@ -3037,7 +3018,7 @@ public class Nd4j {
|
|||
* @return a random matrix of the specified shape and range
|
||||
*/
|
||||
public static INDArray rand(double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) {
|
||||
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng);
|
||||
INDArray ret = createUninitialized(shape, order());
|
||||
return rand(ret, min, max, rng);
|
||||
}
|
||||
|
||||
|
@ -3052,7 +3033,7 @@ public class Nd4j {
|
|||
* @return a drandom matrix of the specified shape and range
|
||||
*/
|
||||
public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||
INDArray ret = createUninitialized(rows, columns);//INSTANCE.rand(rows, columns, min, max, rng);
|
||||
INDArray ret = createUninitialized(rows, columns);
|
||||
return rand(ret, min, max, rng);
|
||||
}
|
||||
|
||||
|
@ -3081,7 +3062,7 @@ public class Nd4j {
|
|||
* Create a ndarray of the given shape and data type with values from N(0,1)
|
||||
*
|
||||
* @param shape the shape of the ndarray
|
||||
* @return
|
||||
* @return new array with random values
|
||||
*/
|
||||
public static INDArray randn(@NonNull DataType dataType, @NonNull int... shape) {
|
||||
return randn(dataType, ArrayUtil.toLongArray(shape));
|
||||
|
@ -3161,7 +3142,7 @@ public class Nd4j {
|
|||
* Random normal N(0, 1) using the specified seed
|
||||
*
|
||||
* @param shape the shape of the array
|
||||
* @return
|
||||
* @return new array with random values
|
||||
*/
|
||||
public static INDArray randn(long seed, @NonNull long... shape) {
|
||||
INDArray ret = Nd4j.createUninitialized(shape, order());
|
||||
|
@ -3173,7 +3154,7 @@ public class Nd4j {
|
|||
*
|
||||
* @param rows the number of rows in the matrix
|
||||
* @param columns the number of columns in the matrix
|
||||
* @return
|
||||
* @return new array with random values
|
||||
*/
|
||||
public static INDArray randn(long rows, long columns) {
|
||||
INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
|
||||
|
@ -3197,7 +3178,7 @@ public class Nd4j {
|
|||
*
|
||||
* @param rows the number of rows in the matrix
|
||||
* @param columns the number of columns in the matrix
|
||||
* @return
|
||||
* @return new array with random values
|
||||
*/
|
||||
public static INDArray randn(long rows, long columns, long seed) {
|
||||
INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
|
||||
|
@ -3210,7 +3191,7 @@ public class Nd4j {
|
|||
* @param rows the number of rows in the matrix
|
||||
* @param columns the number of columns in the matrix
|
||||
* @param r the random generator to use
|
||||
* @return
|
||||
* @return new array with random values
|
||||
*/
|
||||
public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) {
|
||||
INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
|
||||
|
@ -3238,7 +3219,7 @@ public class Nd4j {
|
|||
*
|
||||
* @param shape the shape of the array
|
||||
* @param r the random generator to use
|
||||
* @return
|
||||
* @return new array with random values
|
||||
*/
|
||||
public static INDArray randn(@NonNull org.nd4j.linalg.api.rng.Random r, @NonNull long... shape) {
|
||||
final INDArray ret = Nd4j.createUninitialized(shape, order());
|
||||
|
@ -3409,9 +3390,9 @@ public class Nd4j {
|
|||
*
|
||||
* PLEASE NOTE: memory of underlying array will be NOT initialized, and won't be set to 0.0
|
||||
*
|
||||
* @param rows
|
||||
* @param columns
|
||||
* @return
|
||||
* @param rows rows
|
||||
* @param columns columns
|
||||
* @return uninitialized 2D array of rows x columns
|
||||
*/
|
||||
public static INDArray createUninitialized(long rows, long columns) {
|
||||
return createUninitialized(new long[] {rows, columns});
|
||||
|
@ -3532,7 +3513,7 @@ public class Nd4j {
|
|||
* @return the created ndarray.
|
||||
*/
|
||||
public static INDArray create(double[][][] data) {
|
||||
return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length});
|
||||
return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3541,7 +3522,7 @@ public class Nd4j {
|
|||
* @return the created ndarray.
|
||||
*/
|
||||
public static INDArray create(float[][][] data) {
|
||||
return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length});
|
||||
return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3559,7 +3540,7 @@ public class Nd4j {
|
|||
* @return the created ndarray.
|
||||
*/
|
||||
public static INDArray create(double[][][][] data) {
|
||||
return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length, data[0][0][0].length});
|
||||
return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length, data[0][0][0].length);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3568,7 +3549,7 @@ public class Nd4j {
|
|||
* @return the created ndarray.
|
||||
*/
|
||||
public static INDArray create(float[][][][] data) {
|
||||
return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length, data[0][0][0].length});
|
||||
return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length, data[0][0][0].length);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3609,8 +3590,7 @@ public class Nd4j {
|
|||
* @return the created ndarray
|
||||
*/
|
||||
public static INDArray create(float[] data, char order) {
|
||||
INDArray ret = INSTANCE.create(data, order);
|
||||
return ret;
|
||||
return INSTANCE.create(data, order);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3621,8 +3601,7 @@ public class Nd4j {
|
|||
* @return the created ndarray
|
||||
*/
|
||||
public static INDArray create(double[] data, char order) {
|
||||
INDArray ret = INSTANCE.create(data, order);
|
||||
return ret;
|
||||
return INSTANCE.create(data, order);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3633,8 +3612,7 @@ public class Nd4j {
|
|||
* @return the created ndarray
|
||||
*/
|
||||
public static INDArray create(int columns, char order) {
|
||||
INDArray ret = INSTANCE.create(new long[] {columns}, Nd4j.getStrides(new long[] {columns}, order), 0, order);
|
||||
return ret;
|
||||
return INSTANCE.create(new long[] {columns}, Nd4j.getStrides(new long[] {columns}, order), 0, order);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3719,8 +3697,7 @@ public class Nd4j {
|
|||
* @return the created ndarray.
|
||||
*/
|
||||
public static INDArray create(int[] data, long[] shape, long[]strides, char order, DataType type) {
|
||||
val ret = INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
return ret;
|
||||
return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -101,6 +101,22 @@ public class RngTests extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRandomBinomial() {
|
||||
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
|
||||
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);
|
||||
assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial
|
||||
|
||||
x = Nd4j.randomBinomial(10, 0.5, x);
|
||||
assertTrue(x.sum().getDouble(0) > 0.0);
|
||||
|
||||
x = Nd4j.randomExponential(0.5, 3,3);
|
||||
assertTrue(x.sum().getDouble(0) > 0.0);
|
||||
|
||||
x = Nd4j.randomExponential(0.5, x);
|
||||
assertTrue(x.sum().getDouble(0) > 0.0);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
|
Loading…
Reference in New Issue