From ef8cb55b7b2db7c2e5ec2c4c37cc622e7838fe40 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 1 Jul 2019 19:50:33 -0700 Subject: [PATCH] Nd4j: Change array args to vararg and add toString options (#36) * changed [] to ... Signed-off-by: Ryan Nett * added randn(long seed, int... shape) Signed-off-by: Ryan Nett * Fixed a couple of methods Signed-off-by: Ryan Nett * ToString methods w/ options Signed-off-by: Ryan Nett * fixes, less toString methods, and a few ops I missed Signed-off-by: Ryan Nett * some javadocs, change int... to long... where possible Signed-off-by: Ryan Nett * another javadoc Signed-off-by: Ryan Nett * javadoc fix Signed-off-by: Ryan Nett * just javadoc in INDArray Signed-off-by: Ryan Nett * local/static fix Signed-off-by: Ryan Nett * Add @NonNull to options Signed-off-by: Ryan Nett * javadoc updates/fixes Signed-off-by: Ryan Nett * more @NonNulls Signed-off-by: Ryan Nett * even more @NonNulls, this time on varargs Signed-off-by: Ryan Nett --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 22 +- .../linalg/api/ndarray/BaseSparseNDArray.java | 23 ++ .../org/nd4j/linalg/api/ndarray/INDArray.java | 29 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 344 +++++++++++------- .../nd4j/linalg/string/NDArrayStrings.java | 54 ++- .../java/org/nd4j/linalg/ToStringTest.java | 65 ++++ .../org/nd4j/linalg/api/buffer/DataType.java | 30 ++ 7 files changed, 419 insertions(+), 148 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 44cf0c233..56c704b03 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -38,7 +38,6 @@ import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.custom.BarnesHutGains; import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; import org.nd4j.linalg.api.ops.impl.reduce.floating.*; @@ -5995,13 +5994,30 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public String toString() { + return toString(new NDArrayStrings()); + } + + + @Override + public String toString(@NonNull NDArrayStrings options){ if (!isCompressed() && !preventUnpack) - return new NDArrayStrings().format(this); + return options.format(this); else if (isCompressed() && compressDebug) return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. This is to prevent auto decompression from being triggered."; else if (preventUnpack) return "Array string unpacking is disabled."; - return new NDArrayStrings().format(this); + return options.format(this); + } + + @Override + public String toString(long maxElements, boolean forceSummarize, int precision){ + return toString(new NDArrayStrings(maxElements, forceSummarize, precision)); + } + + + @Override + public String toStringFull(){ + return toString(Long.MAX_VALUE, false, -1 * dataType().precision()); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java index 8dadaaa82..3570ed7ad 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java @@ -39,6 +39,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; import org.nd4j.linalg.indexing.conditions.Condition; import org.nd4j.linalg.primitives.Pair; +import org.nd4j.linalg.string.NDArrayStrings; import org.nd4j.linalg.util.LinAlgExceptions; import java.nio.LongBuffer; @@ -47,7 +48,9 @@ import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; +import static org.nd4j.linalg.factory.Nd4j.compressDebug; import static org.nd4j.linalg.factory.Nd4j.createUninitialized; +import static org.nd4j.linalg.factory.Nd4j.preventUnpack; /** * @author Audrey Loeffel @@ -2016,4 +2019,24 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { public INDArray ulike() { throw new UnsupportedOperationException("Not yet implemented"); } + + @Override + public String toString(NDArrayStrings options){ + return "SPARSE ARRAY, TOSTRING NOT SUPPORTED"; + } + + @Override + public String toString(long maxElements, boolean forceSummarize, int decimalPlaces){ + return toString(new NDArrayStrings(maxElements, forceSummarize, decimalPlaces)); + } + + @Override + public String toStringFull(){ + return toString(Long.MAX_VALUE, false, dataType().precision()); + } + + @Override + public String toString(){ + return toString(new NDArrayStrings()); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 351b4659e..7c547f4af 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -16,7 +16,11 @@ package org.nd4j.linalg.api.ndarray; +import static org.nd4j.linalg.factory.Nd4j.compressDebug; +import static org.nd4j.linalg.factory.Nd4j.preventUnpack; + import com.google.flatbuffers.FlatBufferBuilder; +import lombok.NonNull; import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; @@ -28,6 +32,7 @@ import org.nd4j.linalg.indexing.conditions.Condition; import java.io.Serializable; import java.nio.LongBuffer; import java.util.List; +import org.nd4j.linalg.string.NDArrayStrings; /** * Interface for an ndarray @@ -1914,7 +1919,7 @@ public interface INDArray extends Serializable, AutoCloseable { * @param shape * @param stride */ - public void setShapeAndStride(int[] shape, int[] stride); + void setShapeAndStride(int[] shape, int[] stride); /** * Set the ordering @@ -2843,4 +2848,26 @@ public interface INDArray extends Serializable, AutoCloseable { * @return */ //INDArray[] gains(INDArray input, INDArray gradx, INDArray epsilon); + + /** + * Get a string representation of the array with configurable formatting + * @param options format options + */ + String toString(@NonNull NDArrayStrings options); + + + /** + * Get a string representation of the array + * + * @param maxElements Summarize if more than maxElements in the array + * @param forceSummarize Force a summary instead of a full print + * @param precision The number of decimals to print. Doesn't print trailing 0s if negative + */ + String toString(long maxElements, boolean forceSummarize, int precision); + + /** + * ToString with unlimited elements and precision + * @see org.nd4j.linalg.api.ndarray.BaseNDArray#toString(long, boolean, int) + */ + String toStringFull(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 22bf1d840..c441a10f8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -407,7 +407,7 @@ public class Nd4j { * @param dimension the dimension to do the shuffle * @return */ - public static void shuffle(INDArray toShuffle, Random random, int... dimension) { + public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) { INSTANCE.shuffle(toShuffle, random, dimension); } @@ -418,7 +418,7 @@ public class Nd4j { * @param dimension the dimension to do the shuffle * @return */ - public static void shuffle(INDArray toShuffle, int... dimension) { + public static void shuffle(INDArray toShuffle, @NonNull int... dimension) { //shuffle(toShuffle, new Random(), dimension); INSTANCE.shuffle(toShuffle, new Random(), dimension); } @@ -430,7 +430,7 @@ public class Nd4j { * @param dimension the dimension to do the shuffle * @return */ - public static void shuffle(Collection toShuffle, int... dimension) { + public static void shuffle(Collection toShuffle, @NonNull int... dimension) { //shuffle(toShuffle, new Random(), dimension); INSTANCE.shuffle(toShuffle, new Random(), dimension); } @@ -442,7 +442,7 @@ public class Nd4j { * @param dimension the dimension to do the shuffle * @return */ - public static void shuffle(Collection toShuffle, Random rnd, int... dimension) { + public static void shuffle(Collection toShuffle, Random rnd, @NonNull int... dimension) { //shuffle(toShuffle, new Random(), dimension); INSTANCE.shuffle(toShuffle, rnd, dimension); } @@ -659,7 +659,7 @@ public class Nd4j { * @param dimension the dimension along which to get the maximum * @return array of maximum values. */ - public static INDArray argMax(INDArray arr, int... dimension) { + public static INDArray argMax(INDArray arr, @NonNull int... dimension) { IMax imax = new IMax(arr, dimension); return Nd4j.getExecutioner().exec(imax); } @@ -667,7 +667,7 @@ public class Nd4j { /** * @see #argMax(INDArray, int...) */ - public static INDArray argMin(INDArray arr, int... dimension) { + public static INDArray argMin(INDArray arr, @NonNull int... dimension) { IMin imin = new IMin(arr, dimension); return Nd4j.getExecutioner().exec(imin); } @@ -2325,7 +2325,7 @@ public class Nd4j { * @return the flattened representation of * these ndarrays */ - public static INDArray toFlattened(INDArray... matrices) { + public static INDArray toFlattened(@NonNull INDArray... matrices) { return INSTANCE.toFlattened(matrices); } @@ -2337,7 +2337,7 @@ public class Nd4j { * @return the flattened representation of * these ndarrays */ - public static INDArray toFlattened(char order, INDArray... matrices) { + public static INDArray toFlattened(char order, @NonNull INDArray... matrices) { return INSTANCE.toFlattened(order, matrices); } @@ -3071,7 +3071,7 @@ public class Nd4j { return choice(source, probs, numSamples, Nd4j.getRandom()); } - public static INDArray appendBias(INDArray... vectors) { + public static INDArray appendBias(@NonNull INDArray... vectors) { INDArray ret = INSTANCE.appendBias(vectors); return ret; } @@ -3092,13 +3092,12 @@ public class Nd4j { ////////////////////// RANDOM /////////////////////////////// /** - * Create a random ndarray with the given shape using - * the current time as the seed + * Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape * * @param shape the shape of the array * @return the random ndarray with the specified shape */ - public static INDArray rand(int[] shape) { + public static INDArray rand(@NonNull int... shape) { INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom()); return rand(ret); } @@ -3106,18 +3105,20 @@ public class Nd4j { /** * @see #rand(int[]) */ - public static INDArray rand(long[] shape) { + public static INDArray rand(@NonNull long... shape) { INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom()); return rand(ret); } /** - * Create a random ndarray with given type and shape. - * @param dataType datatype - * @param shape shape - * @return new array. + * Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape and data type + * + * @param shape the shape of the ndarray + * @return the random ndarray with the specified shape */ - public static INDArray rand(DataType dataType, long... shape) { + public static INDArray rand(@NonNull DataType dataType, @NonNull long... shape) { + Preconditions.checkArgument(dataType.isFPType(), + "Can't create a random array of a non-floating point data type"); INDArray ret = createUninitialized(dataType, shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom()); return rand(ret); } @@ -3125,52 +3126,65 @@ public class Nd4j { /** * Create a random ndarray with the given shape and array order * + * Values are sampled from a uniform distribution over (0, 1) + * * @param order the order of the ndarray to return * @param shape the shape of the array * @return the random ndarray with the specified shape */ - public static INDArray rand(char order, int[] shape) { + public static INDArray rand(char order, @NonNull int... shape) { INDArray ret = Nd4j.createUninitialized(shape, order); //INSTANCE.rand(order, shape); return rand(ret); } /** - * Create a random ndarray with the given datatype, order and shape. - * - * The datatype must be one of the floating point types. - * - * @param dataType datatype - * @param order order - * @param shape shape - * @return + * @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)) */ - public static INDArray rand(DataType dataType, char order, int[] shape) { - INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), order); //INSTANCE.rand(order, shape); - return rand(ret); + @Deprecated + public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) { + return rand(dataType, order, ArrayUtil.toLongArray(shape)); } /** - * @see #rand(DataType, char, int[]) + * @deprecated use {@link org.nd4j.linalg.factory.Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)} */ - public static INDArray rand(DataType dataType, int[] shape, char order) { - INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), order); //INSTANCE.rand(order, shape); - return rand(ret); + @Deprecated + public static INDArray rand(@NonNull DataType dataType, char order, @NonNull int... shape) { + return rand(dataType, order, ArrayUtil.toLongArray(shape)); } /** - * Create a random ndarray with the given datatype and shape. - * using the default Nd4j order. + * Create a random ndarray with the given shape, data type, and array order * - * @see #rand(DataType, char, int[]) + * Values are sampled from a uniform distribution over (0, 1) + * + * @param order the order of the ndarray to return + * @param shape the shape of the ndarray + * @param dataType the data type of the ndarray + * @return the random ndarray with the specified shape */ - public static INDArray rand(DataType dataType, int[] shape) { + public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) { + INDArray ret = Nd4j.createUninitialized(dataType, shape, order); //INSTANCE.rand(order, shape); + return rand(ret); + } + + + /** + * Create a random ndarray with the given shape and data type + * + * Values are sampled from a uniform distribution over (0, 1) + * + * @param shape the shape of the ndarray + * @param dataType the data type of the ndarray + * @return the random ndarray with the specified shape + */ + public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) { INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); //INSTANCE.rand(order, shape); return rand(ret); } /** - * Create a random ndarray with the given shape using - * the current time as the seed + * Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape * * @param rows the number of rows in the matrix * @param columns the number of columns in the matrix @@ -3187,6 +3201,8 @@ public class Nd4j { /** * Create a random ndarray with the given shape and output order * + * Values are sampled from a uniform distribution over (0, 1) + * * @param rows the number of rows in the matrix * @param columns the number of columns in the matrix * @return the random ndarray with the specified shape @@ -3200,20 +3216,31 @@ public class Nd4j { } /** - * Create a random ndarray with the given shape using given seed + * Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape + * using given seed * * @param shape the shape of the array * @param seed the seed to use * @return the random ndarray with the specified shape */ - public static INDArray rand(int[] shape, long seed) { + public static INDArray rand(long seed, @NonNull long... shape) { INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed); return rand(ret, seed); } + /** + * @deprecated use {@link Nd4j#rand(long, long...)} + */ + @Deprecated + public static INDArray rand(int[] shape, long seed) { + INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed); + return rand(seed, ArrayUtil.toLongArray(shape)); + } + /** - * Create a random ndarray with the given shape using the given seed + * Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape + * using the given seed * * @param rows the number of rows in the matrix * @param columns the columns of the ndarray @@ -3225,6 +3252,14 @@ public class Nd4j { return rand(ret, seed); } + /** + * @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.rng.Random, long...)} + */ + @Deprecated + public static INDArray rand(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random rng) { + return rand(rng, ArrayUtil.toLongArray(shape)); + } + /** * Create a random ndarray with the given shape using the given RandomGenerator * @@ -3232,22 +3267,26 @@ public class Nd4j { * @param rng the random generator to use * @return the random ndarray with the specified shape */ - public static INDArray rand(int[] shape, org.nd4j.linalg.api.rng.Random rng) { + public static INDArray rand(@NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) { INDArray ret = createUninitialized(shape, Nd4j.order()); //INSTANCE.rand(shape, rng); return rand(ret, rng); } /** - * Create a random ndarray with the given shape using the given rng - * - * @param shape the shape of the array - * @param dist distribution to use - * @return the random ndarray with the specified shape + * @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.rng.distribution.Distribution, long...)} */ - public static INDArray rand(int[] shape, Distribution dist) { - //INDArray ret = INSTANCE.rand(shape, dist); - //logCreationIfNecessary(ret); - return dist.sample(shape); + @Deprecated + public static INDArray rand(int[] shape, @NonNull Distribution dist) { + return rand(dist, ArrayUtil.toLongArray(shape)); + } + + /** + * @deprecated use + * {@link org.nd4j.linalg.factory.Nd4j#rand(org.nd4j.linalg.api.rng.distribution.Distribution, long...)} + */ + @Deprecated + public static INDArray rand(long[] shape, @NonNull Distribution dist) { + return rand(dist, shape); } /** @@ -3257,7 +3296,7 @@ public class Nd4j { * @param dist distribution to use * @return the random ndarray with the specified shape */ - public static INDArray rand(long[] shape, Distribution dist) { + public static INDArray rand(@NonNull Distribution dist, @NonNull long... shape) { //INDArray ret = INSTANCE.rand(shape, dist); //logCreationIfNecessary(ret); return dist.sample(shape); @@ -3271,21 +3310,24 @@ public class Nd4j { * @param rng the random generator to use * @return the random ndarray with the specified shape */ - public static INDArray rand(int rows, int columns, org.nd4j.linalg.api.rng.Random rng) { + public static INDArray rand(int rows, int columns, @NonNull org.nd4j.linalg.api.rng.Random rng) { INDArray ret = createUninitialized(new int[] {rows, columns}, order());//INSTANCE.rand(rows, columns, rng); return rand(ret, rng); } /** - * Generates a random matrix between min and max - * - * @param shape the number of rows of the matrix - * @param min the minimum number - * @param max the maximum number - * @param rng the rng to use - * @return a random matrix of the specified shape and range + * @deprecated use {@link Nd4j#rand(double, double, org.nd4j.linalg.api.rng.Random, long...)} */ - public static INDArray rand(int[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) { + @Deprecated + public static INDArray rand(int[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { + return rand(min, max, rng, ArrayUtil.toLongArray(shape)); + } + + /** + * @deprecated use {@link Nd4j#rand(double, double, org.nd4j.linalg.api.rng.Random, long...)} + */ + @Deprecated + public static INDArray rand(long[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng); return rand(ret, min, max, rng); } @@ -3299,7 +3341,7 @@ public class Nd4j { * @param rng the rng to use * @return a random matrix of the specified shape and range */ - public static INDArray rand(long[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) { + public static INDArray rand(double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) { INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng); return rand(ret, min, max, rng); } @@ -3314,7 +3356,7 @@ public class Nd4j { * @param rng the rng to use * @return a drandom matrix of the specified shape and range */ - public static INDArray rand(int rows, int columns, double min, double max, org.nd4j.linalg.api.rng.Random rng) { + public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { INDArray ret = createUninitialized(rows, columns);//INSTANCE.rand(rows, columns, min, max, rng); return rand(ret, min, max, rng); } @@ -3330,34 +3372,47 @@ public class Nd4j { } /** - * Random normal using the current time stamp - * as the seed + * Create a ndarray of the given shape with values from N(0,1) * * @param shape the shape of the array * @return new array with random values */ - public static INDArray randn(int[] shape) { - INDArray ret = Nd4j.createUninitialized(shape, order()); - return randn(ret); + public static INDArray randn(@NonNull int... shape) { + return randn(ArrayUtil.toLongArray(shape)); + } + + + /** + * Create a ndarray of the given shape and data type with values from N(0,1) + * + * @param shape the shape of the ndarray + * @return + */ + public static INDArray randn(@NonNull DataType dataType, @NonNull int... shape) { + return randn(dataType, ArrayUtil.toLongArray(shape)); } /** - * Random normal ndarray of given datatype and shape, + * Create a ndarray of the given shape and data type with values from N(0,1) + * * @param dataType datatype to use, must be a float type datatype. * @param shape shape for the new array. * @return new array with random values */ - public static INDArray randn(DataType dataType, long... shape) { + public static INDArray randn(@NonNull DataType dataType, @NonNull long... shape) { INDArray ret = Nd4j.createUninitialized(dataType, shape, order()); return randn(ret); } + /** - * Random normal ndarray of given shape. defaults to FLOAT and c-order. + * Create a ndarray of the given shape with values from N(0,1). + * Defaults to FLOAT and c-order. + * * @param shape shape for the new array. * @return new array with random values */ - public static INDArray randn(long... shape) { + public static INDArray randn(@NonNull long... shape) { INDArray ret = Nd4j.createUninitialized(shape, order()); return randn(ret); } @@ -3369,7 +3424,7 @@ public class Nd4j { * @param shape the shape of the array * @return new array with random values */ - public static INDArray randn(char order, int[] shape) { + public static INDArray randn(char order, @NonNull int... shape) { INDArray ret = Nd4j.createUninitialized(shape, order); return randn(ret); } @@ -3381,34 +3436,45 @@ public class Nd4j { * @param shape the shape of the array * @return new array with random values */ - public static INDArray randn(char order, long[] shape) { + public static INDArray randn(char order, @NonNull long... shape) { INDArray ret = Nd4j.createUninitialized(shape, order); return randn(ret); } /** - * @see #rand(DataType, char, int[]) + * Random normal N(0,1) with the specified shape and array order + * + * @param order order of the output ndarray + * @param shape the shape of the ndarray + * @param dataType the data type of the ndarray */ - public static INDArray randn(DataType dataType, char order, long[] shape) { + public static INDArray randn(@NonNull DataType dataType, char order, @NonNull long... shape) { INDArray ret = Nd4j.createUninitialized(dataType, shape, order); return randn(ret); } /** - * Random normal using the specified seed + * @deprecated use {@link Nd4j#randn(long, long...)} + */ + @Deprecated + public static INDArray randn(int[] shape, long seed) { + return randn(seed, ArrayUtil.toLongArray(shape)); + } + + /** + * Random normal N(0, 1) using the specified seed * * @param shape the shape of the array * @return */ - public static INDArray randn(int[] shape, long seed) { + public static INDArray randn(long seed, @NonNull long... shape) { INDArray ret = Nd4j.createUninitialized(shape, order()); return randn(ret, seed); } /** - * Random normal using the current time stamp - * as the seed + * Random normal N(0, 1) * * @param rows the number of rows in the matrix * @param columns the number of columns in the matrix @@ -3451,21 +3517,25 @@ public class Nd4j { * @param r the random generator to use * @return */ - public static INDArray randn(long rows, long columns, org.nd4j.linalg.api.rng.Random r) { + public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) { INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); return randn(ret, r); } /** - * Random normal using the given rng - * - * @param shape the shape of the array - * @param r the random generator to use - * @return + * @deprecated use {@link Nd4j#randn(org.nd4j.linalg.api.rng.Random, long...)} */ - public static INDArray randn(int[] shape, org.nd4j.linalg.api.rng.Random r) { - final INDArray ret = Nd4j.createUninitialized(shape, order()); - return randn(ret, r); + @Deprecated + public static INDArray randn(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random r) { + return randn(r, ArrayUtil.toLongArray(shape)); + } + + /** + * @deprecated use {@link Nd4j#randn(org.nd4j.linalg.api.rng.Random, long...)} + */ + @Deprecated + public static INDArray randn(long[] shape, @NonNull org.nd4j.linalg.api.rng.Random r) { + return randn(r, shape); } /** @@ -3475,7 +3545,7 @@ public class Nd4j { * @param r the random generator to use * @return */ - public static INDArray randn(long[] shape, org.nd4j.linalg.api.rng.Random r) { + public static INDArray randn(@NonNull org.nd4j.linalg.api.rng.Random r, @NonNull long... shape) { final INDArray ret = Nd4j.createUninitialized(shape, order()); return randn(ret, r); } @@ -3509,7 +3579,7 @@ public class Nd4j { * @param rng the random generator to use * @return the given target array */ - public static INDArray rand(INDArray target, org.nd4j.linalg.api.rng.Random rng) { + public static INDArray rand(INDArray target, @NonNull org.nd4j.linalg.api.rng.Random rng) { return getExecutioner().exec(new UniformDistribution(target), rng); } @@ -3520,7 +3590,7 @@ public class Nd4j { * @param dist distribution to use * @return the random ndarray with the specified shape */ - public static INDArray rand(INDArray target, Distribution dist) { + public static INDArray rand(INDArray target, @NonNull Distribution dist) { return dist.sample(target); } @@ -3533,7 +3603,7 @@ public class Nd4j { * @param rng the random generator to use * @return the given target array */ - public static INDArray rand(INDArray target, double min, double max, org.nd4j.linalg.api.rng.Random rng) { + public static INDArray rand(INDArray target, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { if (min > max) throw new IllegalArgumentException("the maximum value supplied is smaller than the minimum"); return getExecutioner().exec(new UniformDistribution(target, min, max), rng); @@ -3557,7 +3627,7 @@ public class Nd4j { * @param rng the random generator to use * @return the given target array */ - public static INDArray randn(INDArray target, org.nd4j.linalg.api.rng.Random rng) { + public static INDArray randn(INDArray target, @NonNull org.nd4j.linalg.api.rng.Random rng) { return getExecutioner().exec(new GaussianDistribution(target), rng); } @@ -3569,7 +3639,7 @@ public class Nd4j { * @param shape Shape of the result array * @return Result array */ - public static INDArray randomBernoulli(double p, long... shape) { + public static INDArray randomBernoulli(double p, @NonNull long... shape) { return randomBernoulli(p, Nd4j.createUninitialized(shape)); } @@ -3595,7 +3665,7 @@ public class Nd4j { * @param shape Shape of the result array * @return Result array */ - public static INDArray randomBinomial(int nTrials, double p, long... shape) { + public static INDArray randomBinomial(int nTrials, double p, @NonNull long... shape) { return randomBinomial(nTrials, p, Nd4j.createUninitialized(shape)); } @@ -3757,7 +3827,7 @@ public class Nd4j { * @param shape desired shape of new array. * @return the created ndarray. */ - public static INDArray create(boolean[][] data, long[] shape) { + public static INDArray create(boolean[][] data, @NonNull long... shape) { return INSTANCE.create(ArrayUtil.flatten(data), shape, getStrides(shape), DataType.BOOL, Nd4j.getMemoryManager().getCurrentWorkspace()); } @@ -4083,7 +4153,7 @@ public class Nd4j { * @param shape the shape of the array * @return the created ndarray */ - public static INDArray create(float[] data, int[] shape) { + public static INDArray create(float[] data, int... shape) { if (shape.length == 0 && data.length == 1) { return scalar(data[0]); } @@ -4102,7 +4172,7 @@ public class Nd4j { /** * @see #create(float[], int[]) */ - public static INDArray create(float[] data, long[] shape) { + public static INDArray create(float[] data, long... shape) { if (shape.length == 0 && data.length == 1) { return scalar(data[0]); } @@ -4121,7 +4191,7 @@ public class Nd4j { /** * @see #create(float[], int[]) */ - public static INDArray create(double[] data, long[] shape) { + public static INDArray create(double[] data, long... shape) { if (shape.length == 0 && data.length == 1) { return scalar(data[0]); } @@ -4144,7 +4214,7 @@ public class Nd4j { * @param shape the shape of the array * @return the created ndarray */ - public static INDArray create(double[] data, int[] shape) { + public static INDArray create(double[] data, int... shape) { if (shape.length == 1) { if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); @@ -4346,7 +4416,7 @@ public class Nd4j { * @param shape desired shape of new array. Must match the resulting shape of combining the list. * @return the instance */ - public static INDArray create(List list, int[] shape) { + public static INDArray create(List list, int... shape) { checkShapeValues(shape); INDArray ret = INSTANCE.create(list, shape); @@ -4356,7 +4426,7 @@ public class Nd4j { /** * @see #create(List, int[]) */ - public static INDArray create(List list, long[] shape) { + public static INDArray create(List list, long... shape) { checkShapeValues(shape); INDArray ret = INSTANCE.create(list, shape); @@ -4641,7 +4711,7 @@ public class Nd4j { * @param shape desired shape of new array. * @return the created ndarray. */ - public static INDArray create(DataBuffer data, int[] shape) { + public static INDArray create(DataBuffer data, int... shape) { checkShapeValues(shape); return INSTANCE.create(data, shape); } @@ -4649,7 +4719,7 @@ public class Nd4j { /** * @see #create(DataBuffer, int[]) */ - public static INDArray create(DataBuffer data, long[] shape) { + public static INDArray create(DataBuffer data, long... shape) { checkShapeValues(shape); return INSTANCE.create(data, shape); } @@ -5049,7 +5119,7 @@ public class Nd4j { * * @param shape */ - public static void checkShapeValues(long[] shape) { + public static void checkShapeValues(long... shape) { for (long e: shape) { if (e < 0) throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape) @@ -5062,7 +5132,7 @@ public class Nd4j { * * @param shape */ - public static void checkShapeValues(int[] shape) { + public static void checkShapeValues(int... shape) { for (int e: shape) { if (e < 1) throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape) @@ -5070,7 +5140,7 @@ public class Nd4j { } } - protected static void checkShapeValues(int length, int[] shape) { + protected static void checkShapeValues(int length, int... shape) { checkShapeValues(shape); if (ArrayUtil.prodLong(shape) > length) @@ -5078,7 +5148,7 @@ public class Nd4j { + " doesn't match data length: " + length); } - protected static void checkShapeValues(int length, long[] shape) { + protected static void checkShapeValues(int length, long... shape) { checkShapeValues(shape); if (ArrayUtil.prodLong(shape) > length) @@ -5171,7 +5241,7 @@ public class Nd4j { /** * @see #createUninitialized(long[]) */ - public static INDArray createUninitialized(int[] shape) { + public static INDArray createUninitialized(int... shape) { if(shape.length == 0) return Nd4j.scalar(dataType(), 0.0); checkShapeValues(shape); @@ -5186,7 +5256,7 @@ public class Nd4j { * @param shape the shape of the array * @return the instance */ - public static INDArray createUninitialized(long[] shape) { + public static INDArray createUninitialized(long... shape) { checkShapeValues(shape); //ensure shapes that wind up being scalar end up with the write shape return createUninitialized(shape, Nd4j.order()); @@ -5199,7 +5269,7 @@ public class Nd4j { * @param shape * @return */ - public static INDArray createUninitializedDetached(int[] shape) { + public static INDArray createUninitializedDetached(int... shape) { return createUninitializedDetached(shape, Nd4j.order()); } @@ -5210,7 +5280,7 @@ public class Nd4j { * @param shape * @return */ - public static INDArray createUninitializedDetached(long[] shape) { + public static INDArray createUninitializedDetached(long... shape) { return createUninitializedDetached(shape, Nd4j.order()); } @@ -5448,7 +5518,7 @@ public class Nd4j { * @param shape the shape of the array * @return the created array. */ - public static INDArray zeros(DataType dataType, long... shape) { + public static INDArray zeros(DataType dataType, @NonNull long... shape) { return INSTANCE.create(dataType, shape, 'c', Nd4j.getMemoryManager().getCurrentWorkspace()); } @@ -5621,7 +5691,7 @@ public class Nd4j { * @param shape Shape fo the array * @return the created ndarray */ - public static INDArray ones(DataType dataType, long... shape) { + public static INDArray ones(DataType dataType, @NonNull long... shape) { if(shape.length == 0) return Nd4j.scalar(dataType, 1.0); INDArray ret = INSTANCE.createUninitialized(dataType, shape, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace()); @@ -5635,8 +5705,9 @@ public class Nd4j { * * @param arrs the first matrix to concat */ - public static INDArray hstack(INDArray... arrs) { - return INSTANCE.hstack(arrs); + public static INDArray hstack(@NonNull INDArray... arrs) { + INDArray ret = INSTANCE.hstack(arrs); + return ret; } /** @@ -5656,7 +5727,7 @@ public class Nd4j { * * @param arrs Arrays to vstack */ - public static INDArray vstack(INDArray... arrs) { + public static INDArray vstack(@NonNull INDArray... arrs) { Preconditions.checkState(arrs != null && arrs.length > 0, "No input specified to vstack (null or length 0)"); if(arrs[0].rank() == 1){ //Edge case: vstack rank 1 arrays - gives rank 2... vstack([3],[3]) -> [2,3] @@ -5758,7 +5829,7 @@ public class Nd4j { * @param arrays * @return */ - public static INDArray accumulate(INDArray... arrays) { + public static INDArray accumulate(@NonNull INDArray... arrays) { if (arrays == null|| arrays.length == 0) throw new ND4JIllegalStateException("Input for accumulation is null or empty"); @@ -5798,7 +5869,7 @@ public class Nd4j { * @param indexes indexes from source array * @return */ - public static INDArray pullRows(INDArray source, int sourceDimension, int[] indexes) { + public static INDArray pullRows(INDArray source, int sourceDimension, @NonNull int... indexes) { return pullRows(source, sourceDimension, indexes, Nd4j.order()); } @@ -5847,7 +5918,7 @@ public class Nd4j { * @param indexes indexes from source array * @return Destination array with specified tensors */ - public static INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, int[] indexes){ + public static INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, @NonNull int... indexes){ if (sourceDimension >= source.rank()) throw new IllegalStateException("Source dimension can't be higher the rank of source tensor"); @@ -5880,7 +5951,7 @@ public class Nd4j { * @return Output array * @see #concat(int, INDArray...) */ - public static INDArray stack(int axis, INDArray... values){ + public static INDArray stack(int axis, @NonNull INDArray... values){ Preconditions.checkArgument(values != null && values.length > 0, "No inputs: %s", values); Preconditions.checkState(axis >= -(values[0].rank()+1) && axis < values[0].rank()+1, "Invalid axis: must be between " + "%s (inclusive) and %s (exclusive) for rank %s input, got %s", -(values[0].rank()+1), values[0].rank()+1, @@ -5902,7 +5973,7 @@ public class Nd4j { * the ndarray shapes save the dimension shape specified * which is then the sum of the sizes along that dimension */ - public static INDArray concat(int dimension, INDArray... toConcat) { + public static INDArray concat(int dimension, @NonNull INDArray... toConcat) { if(dimension < 0) { dimension += toConcat[0].rank(); } @@ -5919,8 +5990,9 @@ public class Nd4j { * @param toConcat * @return */ - public static INDArray specialConcat(int dimension, INDArray... toConcat) { - return INSTANCE.specialConcat(dimension, toConcat); + public static INDArray specialConcat(int dimension, @NonNull INDArray... toConcat) { + INDArray ret = INSTANCE.specialConcat(dimension, toConcat); + return ret; } /** @@ -5950,7 +6022,7 @@ public class Nd4j { * @param shape the shape of the array * @return an ndarray with ones filled in */ - public static INDArray zeros(int... shape) { + public static INDArray zeros(@NonNull int... shape) { return Nd4j.create(shape); } @@ -5961,7 +6033,7 @@ public class Nd4j { * @param shape the shape of the array * @return an ndarray with ones filled in */ - public static INDArray zeros(long... shape) { + public static INDArray zeros(@NonNull long... shape) { return Nd4j.create(shape); } @@ -6085,7 +6157,7 @@ public class Nd4j { * @return the strides for the given shape * and order specified by NDArrays.order() */ - public static int[] getStrides(int[] shape) { + public static int[] getStrides(@NonNull int... shape) { return getStrides(shape, Nd4j.order()); } @@ -6097,7 +6169,7 @@ public class Nd4j { * @return the strides for the given shape * and order specified by NDArrays.order() */ - public static long[] getStrides(long[] shape) { + public static long[] getStrides(@NonNull long... shape) { return getStrides(shape, Nd4j.order()); } @@ -6108,7 +6180,7 @@ public class Nd4j { * @param repeat the shape to repeat * @return the tiled ndarray */ - public static INDArray tile(INDArray tile, int... repeat) { + public static INDArray tile(INDArray tile, @NonNull int... repeat) { int d = repeat.length; long[] shape = ArrayUtil.copy(tile.shape()); long n = Math.max(tile.length(), 1); @@ -6157,11 +6229,11 @@ public class Nd4j { * @return the strides for the given shape * and order specified by NDArrays.order() */ - public static int[] getComplexStrides(int[] shape) { + public static int[] getComplexStrides(@NonNull int... shape) { return getComplexStrides(shape, Nd4j.order()); } - public static long[] getComplexStrides(long[] shape) { + public static long[] getComplexStrides(@NonNull long... shape) { return getComplexStrides(shape, Nd4j.order()); } @@ -6501,7 +6573,7 @@ public class Nd4j { * * @return */ - public static INDArray pile(INDArray... arrays) { + public static INDArray pile(@NonNull INDArray... arrays) { // if we have vectors as input, it's just vstack use case long[] shape = arrays[0].shape(); @@ -6536,7 +6608,7 @@ public class Nd4j { * @param dimensions * @return */ - public static INDArray[] tear(INDArray tensor, int... dimensions) { + public static INDArray[] tear(INDArray tensor, @NonNull int... dimensions) { if (dimensions.length >= tensor.rank()) throw new ND4JIllegalStateException("Target dimensions number should be less tensor rank"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java index 2253fdf4f..d6de21130 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java @@ -40,6 +40,8 @@ import java.util.Locale; * @author Adam Gibson * @author Susan Eraly */ +@Getter +@Setter public class NDArrayStrings { public static final String EMPTY_ARRAY_STR = "[]"; @@ -55,7 +57,7 @@ public class NDArrayStrings { @Setter @Getter private static long maxPrintElements = DEFAULT_MAX_PRINT_ELEMENTS; - + private long localMaxPrintElements = maxPrintElements; private String colSep = ","; private String newLineSep = ","; private int padding = 7; @@ -82,6 +84,38 @@ public class NDArrayStrings { this(",", precision); } + public NDArrayStrings(long maxElements, int precision) { + this(",", precision); + this.localMaxPrintElements = maxElements; + } + + public NDArrayStrings(long maxElements) { + this(); + this.localMaxPrintElements = maxElements; + } + + public NDArrayStrings(long maxElements, boolean forceSummarize, int precision) { + this(",", precision); + if(forceSummarize) + localMaxPrintElements = 0; + else + localMaxPrintElements = maxElements; + } + + public NDArrayStrings(boolean forceSummarize, int precision) { + this(",", precision); + if(forceSummarize) + localMaxPrintElements = 0; + } + + + + public NDArrayStrings(boolean forceSummarize) { + this(",", 4); + if(forceSummarize) + localMaxPrintElements = 0; + } + /** * Specify a delimiter for elements in columns for 2d arrays (or in the rank-1th dimension in higher order arrays) @@ -93,13 +127,17 @@ public class NDArrayStrings { public NDArrayStrings(String colSep, int precision) { this.colSep = colSep; if (!colSep.replaceAll("\\s", "").equals(",")) this.newLineSep = ""; - this.precision = precision; - String decFormatNum = "0."; - while (precision > 0) { - decFormatNum += "0"; - precision -= 1; + StringBuilder decFormatNum = new StringBuilder("0."); + + int prec = Math.abs(precision); + this.precision = prec; + boolean useHash = precision < 0; + + while (prec > 0) { + decFormatNum.append(useHash ? "#" : "0"); + prec -= 1; } - this.decimalFormat = localeIndifferentDecimalFormat(decFormatNum); + this.decimalFormat = localeIndifferentDecimalFormat(decFormatNum.toString()); } /** @@ -147,7 +185,7 @@ public class NDArrayStrings { if (this.scientificFormat.length() + 2 > this.padding) this.padding = this.scientificFormat.length() + 2; this.maxToPrintWithoutSwitching = Math.pow(10,this.precision); this.minToPrintWithoutSwitching = 1.0/(this.maxToPrintWithoutSwitching); - return format(arr, 0, summarize && arr.length() > maxPrintElements); + return format(arr, 0, summarize && arr.length() > localMaxPrintElements); } private String format(INDArray arr, int offset, boolean summarize) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java new file mode 100644 index 000000000..eece027ca --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java @@ -0,0 +1,65 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg; + +import static org.junit.Assert.assertEquals; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +@RunWith(Parameterized.class) +@Slf4j +public class ToStringTest extends BaseNd4jTest { + public ToStringTest(Nd4jBackend backend) { + super(backend); + } + + @Test + public void testToString() throws Exception { + assertEquals("[ 1, 2, 3]", + Nd4j.createFromArray(1, 2, 3).toString()); + + assertEquals("[ 1, 2, 3 ... 6 7, 8]", + Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(1000, false, 2)); + + assertEquals("[ 1.132, 2.644, 3.234]", + Nd4j.createFromArray(1.132414, 2.64356456, 3.234234).toString(1000, false, 3)); + + assertEquals("[ 1.132414, 2.64356456, 3.25345234]", + Nd4j.createFromArray(1.132414, 2.64356456, 3.25345234).toStringFull()); + + assertEquals("[ 1, 2, 3 ... 6 7, 8]", + Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(100, true, 1)); + + } + + @Override + public char ordering() { + return 'c'; + } +} diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java index e4976d353..84715f878 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java @@ -124,6 +124,36 @@ public enum DataType { } } + /** + * @return the max number of significant decimal digits + */ + public int precision(){ + switch (this){ + case DOUBLE: + return 17; + case FLOAT: + return 9; + case HALF: + return 5; + case BFLOAT16: + return 4; + case LONG: + case INT: + case SHORT: + case BYTE: + case UBYTE: + case BOOL: + case UTF8: + case COMPRESSED: + case UINT16: + case UINT32: + case UINT64: + case UNKNOWN: + default: + return -1; + } + } + /** * @return For fixed-width types, this returns the number of bytes per array element */