diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 9a28b9155..a5185ef88 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -2554,20 +2554,18 @@ public class Nd4j { * @param dis the data input stream to read from * @return the ndarray */ - public static INDArray read(DataInputStream dis) throws IOException { + public static INDArray read(DataInputStream dis) { val headerShape = BaseDataBuffer.readHeader(dis); - var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle().longValue()}, headerShape.getRight()); + var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle()}, headerShape.getRight()); shapeInformation.read(dis, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getThird()); - val length = Shape.length(shapeInformation); - DataType type = null; + DataType type; DataBuffer data = null; val headerData = BaseDataBuffer.readHeader(dis); try { // current version contains dtype in extras data = CompressedDataBuffer.readUnknown(dis, headerData.getFirst(), headerData.getMiddle(), headerData.getRight()); - type = ArrayOptionsHelper.dataType(shapeInformation.asLong()); } catch (ND4JUnknownDataTypeException e) { // manually setting data type type = headerData.getRight(); @@ -2767,23 +2765,9 @@ public class Nd4j { } public static INDArray appendBias(@NonNull INDArray... vectors) { - INDArray ret = INSTANCE.appendBias(vectors); - return ret; + return INSTANCE.appendBias(vectors); } - /** - * Perform an operation along a diagonal - * - * @param x the ndarray to perform the operation on - * @param func the operation to perform - */ - public static void doAlongDiagonal(INDArray x, Function func) { - if (x.isMatrix()) - for (int i = 0; i < x.rows(); i++) - x.put(i, i, func.apply(x.getDouble(i, i))); - } - - ////////////////////// RANDOM /////////////////////////////// /** @@ -2859,7 +2843,7 @@ public class Nd4j { * @return the random ndarray with the specified shape */ public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) { - INDArray ret = Nd4j.createUninitialized(dataType, shape, order); //INSTANCE.rand(order, shape); + INDArray ret = Nd4j.createUninitialized(dataType, shape, order); return rand(ret); } @@ -2874,7 +2858,7 @@ public class Nd4j { * @return the random ndarray with the specified shape */ public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) { - INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); //INSTANCE.rand(order, shape); + INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); return rand(ret); } @@ -2889,7 +2873,7 @@ public class Nd4j { if (rows < 1 || columns < 1) throw new ND4JIllegalStateException("Number of rows and columns should be positive for new INDArray"); - INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order());//INSTANCE.rand(rows, columns, Nd4j.getRandom()); + INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order()); return rand(ret); } @@ -2928,7 +2912,6 @@ public class Nd4j { */ @Deprecated public static INDArray rand(int[] shape, long seed) { - INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed); return rand(seed, ArrayUtil.toLongArray(shape)); } @@ -2992,8 +2975,6 @@ public class Nd4j { * @return the random ndarray with the specified shape */ public static INDArray rand(@NonNull Distribution dist, @NonNull long... shape) { - //INDArray ret = INSTANCE.rand(shape, dist); - //logCreationIfNecessary(ret); return dist.sample(shape); } @@ -3006,7 +2987,7 @@ public class Nd4j { * @return the random ndarray with the specified shape */ public static INDArray rand(int rows, int columns, @NonNull org.nd4j.linalg.api.rng.Random rng) { - INDArray ret = createUninitialized(new int[] {rows, columns}, order());//INSTANCE.rand(rows, columns, rng); + INDArray ret = createUninitialized(new int[] {rows, columns}, order()); return rand(ret, rng); } @@ -3023,7 +3004,7 @@ public class Nd4j { */ @Deprecated public static INDArray rand(long[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { - INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng); + INDArray ret = createUninitialized(shape, order()); return rand(ret, min, max, rng); } @@ -3037,7 +3018,7 @@ public class Nd4j { * @return a random matrix of the specified shape and range */ public static INDArray rand(double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) { - INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng); + INDArray ret = createUninitialized(shape, order()); return rand(ret, min, max, rng); } @@ -3052,7 +3033,7 @@ public class Nd4j { * @return a drandom matrix of the specified shape and range */ public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { - INDArray ret = createUninitialized(rows, columns);//INSTANCE.rand(rows, columns, min, max, rng); + INDArray ret = createUninitialized(rows, columns); return rand(ret, min, max, rng); } @@ -3081,7 +3062,7 @@ public class Nd4j { * Create a ndarray of the given shape and data type with values from N(0,1) * * @param shape the shape of the ndarray - * @return + * @return new array with random values */ public static INDArray randn(@NonNull DataType dataType, @NonNull int... shape) { return randn(dataType, ArrayUtil.toLongArray(shape)); @@ -3161,7 +3142,7 @@ public class Nd4j { * Random normal N(0, 1) using the specified seed * * @param shape the shape of the array - * @return + * @return new array with random values */ public static INDArray randn(long seed, @NonNull long... shape) { INDArray ret = Nd4j.createUninitialized(shape, order()); @@ -3173,7 +3154,7 @@ public class Nd4j { * * @param rows the number of rows in the matrix * @param columns the number of columns in the matrix - * @return + * @return new array with random values */ public static INDArray randn(long rows, long columns) { INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); @@ -3197,7 +3178,7 @@ public class Nd4j { * * @param rows the number of rows in the matrix * @param columns the number of columns in the matrix - * @return + * @return new array with random values */ public static INDArray randn(long rows, long columns, long seed) { INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); @@ -3210,7 +3191,7 @@ public class Nd4j { * @param rows the number of rows in the matrix * @param columns the number of columns in the matrix * @param r the random generator to use - * @return + * @return new array with random values */ public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) { INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); @@ -3238,7 +3219,7 @@ public class Nd4j { * * @param shape the shape of the array * @param r the random generator to use - * @return + * @return new array with random values */ public static INDArray randn(@NonNull org.nd4j.linalg.api.rng.Random r, @NonNull long... shape) { final INDArray ret = Nd4j.createUninitialized(shape, order()); @@ -3409,9 +3390,9 @@ public class Nd4j { * * PLEASE NOTE: memory of underlying array will be NOT initialized, and won't be set to 0.0 * - * @param rows - * @param columns - * @return + * @param rows rows + * @param columns columns + * @return uninitialized 2D array of rows x columns */ public static INDArray createUninitialized(long rows, long columns) { return createUninitialized(new long[] {rows, columns}); @@ -3532,7 +3513,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(double[][][] data) { - return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length}); + return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length); } /** @@ -3541,7 +3522,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(float[][][] data) { - return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length}); + return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length); } /** @@ -3559,7 +3540,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(double[][][][] data) { - return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length, data[0][0][0].length}); + return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length, data[0][0][0].length); } /** @@ -3568,7 +3549,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(float[][][][] data) { - return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length, data[0][0][0].length}); + return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length, data[0][0][0].length); } /** @@ -3609,8 +3590,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray create(float[] data, char order) { - INDArray ret = INSTANCE.create(data, order); - return ret; + return INSTANCE.create(data, order); } /** @@ -3621,8 +3601,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray create(double[] data, char order) { - INDArray ret = INSTANCE.create(data, order); - return ret; + return INSTANCE.create(data, order); } /** @@ -3633,8 +3612,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray create(int columns, char order) { - INDArray ret = INSTANCE.create(new long[] {columns}, Nd4j.getStrides(new long[] {columns}, order), 0, order); - return ret; + return INSTANCE.create(new long[] {columns}, Nd4j.getStrides(new long[] {columns}, order), 0, order); } /** @@ -3719,8 +3697,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(int[] data, long[] shape, long[]strides, char order, DataType type) { - val ret = INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java index b068a1a65..1bf709a5f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java @@ -101,6 +101,22 @@ public class RngTests extends BaseNd4jTest { } + @Test + void testRandomBinomial() { + //silly tests. Just increasing the usage for randomBinomial to stop compiler warnings. + INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3); + assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial + + x = Nd4j.randomBinomial(10, 0.5, x); + assertTrue(x.sum().getDouble(0) > 0.0); + + x = Nd4j.randomExponential(0.5, 3,3); + assertTrue(x.sum().getDouble(0) > 0.0); + + x = Nd4j.randomExponential(0.5, x); + assertTrue(x.sum().getDouble(0) > 0.0); + } + @Override public char ordering() {