Nd4j: Change array args to vararg and add toString options (#36)

* changed [] to ...

Signed-off-by: Ryan Nett <rnett@skymind.io>

* added randn(long seed, int... shape)

Signed-off-by: Ryan Nett <rnett@skymind.io>

* Fixed a couple of methods

Signed-off-by: Ryan Nett <rnett@skymind.io>

* ToString methods w/ options

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes, less toString methods, and a few ops I missed

Signed-off-by: Ryan Nett <rnett@skymind.io>

* some javadocs, change int... to long... where possible

Signed-off-by: Ryan Nett <rnett@skymind.io>

* another javadoc

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadoc fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* just javadoc in INDArray

Signed-off-by: Ryan Nett <rnett@skymind.io>

* local/static fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* Add @NonNull to options

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadoc updates/fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more @NonNulls

Signed-off-by: Ryan Nett <rnett@skymind.io>

* even more @NonNulls, this time on varargs

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-07-01 19:50:33 -07:00 committed by AlexDBlack
parent 366d850f5e
commit ef8cb55b7b
7 changed files with 419 additions and 148 deletions

View File

@ -38,7 +38,6 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
@ -5995,13 +5994,30 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public String toString() {
return toString(new NDArrayStrings());
}
@Override
public String toString(@NonNull NDArrayStrings options){
if (!isCompressed() && !preventUnpack)
return new NDArrayStrings().format(this);
return options.format(this);
else if (isCompressed() && compressDebug)
return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. This is to prevent auto decompression from being triggered.";
else if (preventUnpack)
return "Array string unpacking is disabled.";
return new NDArrayStrings().format(this);
return options.format(this);
}
@Override
public String toString(long maxElements, boolean forceSummarize, int precision){
return toString(new NDArrayStrings(maxElements, forceSummarize, precision));
}
@Override
public String toStringFull(){
return toString(Long.MAX_VALUE, false, -1 * dataType().precision());
}
/**

View File

@ -39,6 +39,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.string.NDArrayStrings;
import org.nd4j.linalg.util.LinAlgExceptions;
import java.nio.LongBuffer;
@ -47,7 +48,9 @@ import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
import static org.nd4j.linalg.factory.Nd4j.compressDebug;
import static org.nd4j.linalg.factory.Nd4j.createUninitialized;
import static org.nd4j.linalg.factory.Nd4j.preventUnpack;
/**
* @author Audrey Loeffel
@ -2016,4 +2019,24 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
public INDArray ulike() {
throw new UnsupportedOperationException("Not yet implemented");
}
@Override
public String toString(NDArrayStrings options){
return "SPARSE ARRAY, TOSTRING NOT SUPPORTED";
}
@Override
public String toString(long maxElements, boolean forceSummarize, int decimalPlaces){
return toString(new NDArrayStrings(maxElements, forceSummarize, decimalPlaces));
}
@Override
public String toStringFull(){
return toString(Long.MAX_VALUE, false, dataType().precision());
}
@Override
public String toString(){
return toString(new NDArrayStrings());
}
}

View File

@ -16,7 +16,11 @@
package org.nd4j.linalg.api.ndarray;
import static org.nd4j.linalg.factory.Nd4j.compressDebug;
import static org.nd4j.linalg.factory.Nd4j.preventUnpack;
import com.google.flatbuffers.FlatBufferBuilder;
import lombok.NonNull;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
@ -28,6 +32,7 @@ import org.nd4j.linalg.indexing.conditions.Condition;
import java.io.Serializable;
import java.nio.LongBuffer;
import java.util.List;
import org.nd4j.linalg.string.NDArrayStrings;
/**
* Interface for an ndarray
@ -1914,7 +1919,7 @@ public interface INDArray extends Serializable, AutoCloseable {
* @param shape
* @param stride
*/
public void setShapeAndStride(int[] shape, int[] stride);
void setShapeAndStride(int[] shape, int[] stride);
/**
* Set the ordering
@ -2843,4 +2848,26 @@ public interface INDArray extends Serializable, AutoCloseable {
* @return
*/
//INDArray[] gains(INDArray input, INDArray gradx, INDArray epsilon);
/**
* Get a string representation of the array with configurable formatting
* @param options format options
*/
String toString(@NonNull NDArrayStrings options);
/**
* Get a string representation of the array
*
* @param maxElements Summarize if more than maxElements in the array
* @param forceSummarize Force a summary instead of a full print
* @param precision The number of decimals to print. Doesn't print trailing 0s if negative
*/
String toString(long maxElements, boolean forceSummarize, int precision);
/**
* ToString with unlimited elements and precision
* @see org.nd4j.linalg.api.ndarray.BaseNDArray#toString(long, boolean, int)
*/
String toStringFull();
}

View File

@ -407,7 +407,7 @@ public class Nd4j {
* @param dimension the dimension to do the shuffle
* @return
*/
public static void shuffle(INDArray toShuffle, Random random, int... dimension) {
public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) {
INSTANCE.shuffle(toShuffle, random, dimension);
}
@ -418,7 +418,7 @@ public class Nd4j {
* @param dimension the dimension to do the shuffle
* @return
*/
public static void shuffle(INDArray toShuffle, int... dimension) {
public static void shuffle(INDArray toShuffle, @NonNull int... dimension) {
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, new Random(), dimension);
}
@ -430,7 +430,7 @@ public class Nd4j {
* @param dimension the dimension to do the shuffle
* @return
*/
public static void shuffle(Collection<INDArray> toShuffle, int... dimension) {
public static void shuffle(Collection<INDArray> toShuffle, @NonNull int... dimension) {
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, new Random(), dimension);
}
@ -442,7 +442,7 @@ public class Nd4j {
* @param dimension the dimension to do the shuffle
* @return
*/
public static void shuffle(Collection<INDArray> toShuffle, Random rnd, int... dimension) {
public static void shuffle(Collection<INDArray> toShuffle, Random rnd, @NonNull int... dimension) {
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, rnd, dimension);
}
@ -659,7 +659,7 @@ public class Nd4j {
* @param dimension the dimension along which to get the maximum
* @return array of maximum values.
*/
public static INDArray argMax(INDArray arr, int... dimension) {
public static INDArray argMax(INDArray arr, @NonNull int... dimension) {
IMax imax = new IMax(arr, dimension);
return Nd4j.getExecutioner().exec(imax);
}
@ -667,7 +667,7 @@ public class Nd4j {
/**
* @see #argMax(INDArray, int...)
*/
public static INDArray argMin(INDArray arr, int... dimension) {
public static INDArray argMin(INDArray arr, @NonNull int... dimension) {
IMin imin = new IMin(arr, dimension);
return Nd4j.getExecutioner().exec(imin);
}
@ -2325,7 +2325,7 @@ public class Nd4j {
* @return the flattened representation of
* these ndarrays
*/
public static INDArray toFlattened(INDArray... matrices) {
public static INDArray toFlattened(@NonNull INDArray... matrices) {
return INSTANCE.toFlattened(matrices);
}
@ -2337,7 +2337,7 @@ public class Nd4j {
* @return the flattened representation of
* these ndarrays
*/
public static INDArray toFlattened(char order, INDArray... matrices) {
public static INDArray toFlattened(char order, @NonNull INDArray... matrices) {
return INSTANCE.toFlattened(order, matrices);
}
@ -3071,7 +3071,7 @@ public class Nd4j {
return choice(source, probs, numSamples, Nd4j.getRandom());
}
public static INDArray appendBias(INDArray... vectors) {
public static INDArray appendBias(@NonNull INDArray... vectors) {
INDArray ret = INSTANCE.appendBias(vectors);
return ret;
}
@ -3092,13 +3092,12 @@ public class Nd4j {
////////////////////// RANDOM ///////////////////////////////
/**
* Create a random ndarray with the given shape using
* the current time as the seed
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape
*
* @param shape the shape of the array
* @return the random ndarray with the specified shape
*/
public static INDArray rand(int[] shape) {
public static INDArray rand(@NonNull int... shape) {
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
return rand(ret);
}
@ -3106,18 +3105,20 @@ public class Nd4j {
/**
* @see #rand(int[])
*/
public static INDArray rand(long[] shape) {
public static INDArray rand(@NonNull long... shape) {
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
return rand(ret);
}
/**
* Create a random ndarray with given type and shape.
* @param dataType datatype
* @param shape shape
* @return new array.
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape and data type
*
* @param shape the shape of the ndarray
* @return the random ndarray with the specified shape
*/
public static INDArray rand(DataType dataType, long... shape) {
public static INDArray rand(@NonNull DataType dataType, @NonNull long... shape) {
Preconditions.checkArgument(dataType.isFPType(),
"Can't create a random array of a non-floating point data type");
INDArray ret = createUninitialized(dataType, shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
return rand(ret);
}
@ -3125,52 +3126,65 @@ public class Nd4j {
/**
* Create a random ndarray with the given shape and array order
*
* Values are sampled from a uniform distribution over (0, 1)
*
* @param order the order of the ndarray to return
* @param shape the shape of the array
* @return the random ndarray with the specified shape
*/
public static INDArray rand(char order, int[] shape) {
public static INDArray rand(char order, @NonNull int... shape) {
INDArray ret = Nd4j.createUninitialized(shape, order); //INSTANCE.rand(order, shape);
return rand(ret);
}
/**
* Create a random ndarray with the given datatype, order and shape.
*
* The datatype must be one of the floating point types.
*
* @param dataType datatype
* @param order order
* @param shape shape
* @return
* @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...))
*/
public static INDArray rand(DataType dataType, char order, int[] shape) {
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), order); //INSTANCE.rand(order, shape);
return rand(ret);
@Deprecated
public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) {
return rand(dataType, order, ArrayUtil.toLongArray(shape));
}
/**
* @see #rand(DataType, char, int[])
* @deprecated use {@link org.nd4j.linalg.factory.Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)}
*/
public static INDArray rand(DataType dataType, int[] shape, char order) {
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), order); //INSTANCE.rand(order, shape);
return rand(ret);
@Deprecated
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull int... shape) {
return rand(dataType, order, ArrayUtil.toLongArray(shape));
}
/**
* Create a random ndarray with the given datatype and shape.
* using the default Nd4j order.
* Create a random ndarray with the given shape, data type, and array order
*
* @see #rand(DataType, char, int[])
* Values are sampled from a uniform distribution over (0, 1)
*
* @param order the order of the ndarray to return
* @param shape the shape of the ndarray
* @param dataType the data type of the ndarray
* @return the random ndarray with the specified shape
*/
public static INDArray rand(DataType dataType, int[] shape) {
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) {
INDArray ret = Nd4j.createUninitialized(dataType, shape, order); //INSTANCE.rand(order, shape);
return rand(ret);
}
/**
* Create a random ndarray with the given shape and data type
*
* Values are sampled from a uniform distribution over (0, 1)
*
* @param shape the shape of the ndarray
* @param dataType the data type of the ndarray
* @return the random ndarray with the specified shape
*/
public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) {
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); //INSTANCE.rand(order, shape);
return rand(ret);
}
/**
* Create a random ndarray with the given shape using
* the current time as the seed
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape
*
* @param rows the number of rows in the matrix
* @param columns the number of columns in the matrix
@ -3187,6 +3201,8 @@ public class Nd4j {
/**
* Create a random ndarray with the given shape and output order
*
* Values are sampled from a uniform distribution over (0, 1)
*
* @param rows the number of rows in the matrix
* @param columns the number of columns in the matrix
* @return the random ndarray with the specified shape
@ -3200,20 +3216,31 @@ public class Nd4j {
}
/**
* Create a random ndarray with the given shape using given seed
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape
* using given seed
*
* @param shape the shape of the array
* @param seed the seed to use
* @return the random ndarray with the specified shape
*/
public static INDArray rand(int[] shape, long seed) {
public static INDArray rand(long seed, @NonNull long... shape) {
INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed);
return rand(ret, seed);
}
/**
* @deprecated use {@link Nd4j#rand(long, long...)}
*/
@Deprecated
public static INDArray rand(int[] shape, long seed) {
INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed);
return rand(seed, ArrayUtil.toLongArray(shape));
}
/**
* Create a random ndarray with the given shape using the given seed
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape
* using the given seed
*
* @param rows the number of rows in the matrix
* @param columns the columns of the ndarray
@ -3225,6 +3252,14 @@ public class Nd4j {
return rand(ret, seed);
}
/**
* @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.rng.Random, long...)}
*/
@Deprecated
public static INDArray rand(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random rng) {
return rand(rng, ArrayUtil.toLongArray(shape));
}
/**
* Create a random ndarray with the given shape using the given RandomGenerator
*
@ -3232,22 +3267,26 @@ public class Nd4j {
* @param rng the random generator to use
* @return the random ndarray with the specified shape
*/
public static INDArray rand(int[] shape, org.nd4j.linalg.api.rng.Random rng) {
public static INDArray rand(@NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) {
INDArray ret = createUninitialized(shape, Nd4j.order()); //INSTANCE.rand(shape, rng);
return rand(ret, rng);
}
/**
* Create a random ndarray with the given shape using the given rng
*
* @param shape the shape of the array
* @param dist distribution to use
* @return the random ndarray with the specified shape
* @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.rng.distribution.Distribution, long...)}
*/
public static INDArray rand(int[] shape, Distribution dist) {
//INDArray ret = INSTANCE.rand(shape, dist);
//logCreationIfNecessary(ret);
return dist.sample(shape);
@Deprecated
public static INDArray rand(int[] shape, @NonNull Distribution dist) {
return rand(dist, ArrayUtil.toLongArray(shape));
}
/**
* @deprecated use
* {@link org.nd4j.linalg.factory.Nd4j#rand(org.nd4j.linalg.api.rng.distribution.Distribution, long...)}
*/
@Deprecated
public static INDArray rand(long[] shape, @NonNull Distribution dist) {
return rand(dist, shape);
}
/**
@ -3257,7 +3296,7 @@ public class Nd4j {
* @param dist distribution to use
* @return the random ndarray with the specified shape
*/
public static INDArray rand(long[] shape, Distribution dist) {
public static INDArray rand(@NonNull Distribution dist, @NonNull long... shape) {
//INDArray ret = INSTANCE.rand(shape, dist);
//logCreationIfNecessary(ret);
return dist.sample(shape);
@ -3271,21 +3310,24 @@ public class Nd4j {
* @param rng the random generator to use
* @return the random ndarray with the specified shape
*/
public static INDArray rand(int rows, int columns, org.nd4j.linalg.api.rng.Random rng) {
public static INDArray rand(int rows, int columns, @NonNull org.nd4j.linalg.api.rng.Random rng) {
INDArray ret = createUninitialized(new int[] {rows, columns}, order());//INSTANCE.rand(rows, columns, rng);
return rand(ret, rng);
}
/**
* Generates a random matrix between min and max
*
* @param shape the number of rows of the matrix
* @param min the minimum number
* @param max the maximum number
* @param rng the rng to use
* @return a random matrix of the specified shape and range
* @deprecated use {@link Nd4j#rand(double, double, org.nd4j.linalg.api.rng.Random, long...)}
*/
public static INDArray rand(int[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
@Deprecated
public static INDArray rand(int[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
return rand(min, max, rng, ArrayUtil.toLongArray(shape));
}
/**
* @deprecated use {@link Nd4j#rand(double, double, org.nd4j.linalg.api.rng.Random, long...)}
*/
@Deprecated
public static INDArray rand(long[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng);
return rand(ret, min, max, rng);
}
@ -3299,7 +3341,7 @@ public class Nd4j {
* @param rng the rng to use
* @return a random matrix of the specified shape and range
*/
public static INDArray rand(long[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
public static INDArray rand(double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) {
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng);
return rand(ret, min, max, rng);
}
@ -3314,7 +3356,7 @@ public class Nd4j {
* @param rng the rng to use
* @return a drandom matrix of the specified shape and range
*/
public static INDArray rand(int rows, int columns, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
INDArray ret = createUninitialized(rows, columns);//INSTANCE.rand(rows, columns, min, max, rng);
return rand(ret, min, max, rng);
}
@ -3330,34 +3372,47 @@ public class Nd4j {
}
/**
* Random normal using the current time stamp
* as the seed
* Create a ndarray of the given shape with values from N(0,1)
*
* @param shape the shape of the array
* @return new array with random values
*/
public static INDArray randn(int[] shape) {
INDArray ret = Nd4j.createUninitialized(shape, order());
return randn(ret);
public static INDArray randn(@NonNull int... shape) {
return randn(ArrayUtil.toLongArray(shape));
}
/**
* Create a ndarray of the given shape and data type with values from N(0,1)
*
* @param shape the shape of the ndarray
* @return
*/
public static INDArray randn(@NonNull DataType dataType, @NonNull int... shape) {
return randn(dataType, ArrayUtil.toLongArray(shape));
}
/**
* Random normal ndarray of given datatype and shape,
* Create a ndarray of the given shape and data type with values from N(0,1)
*
* @param dataType datatype to use, must be a float type datatype.
* @param shape shape for the new array.
* @return new array with random values
*/
public static INDArray randn(DataType dataType, long... shape) {
public static INDArray randn(@NonNull DataType dataType, @NonNull long... shape) {
INDArray ret = Nd4j.createUninitialized(dataType, shape, order());
return randn(ret);
}
/**
* Random normal ndarray of given shape. defaults to FLOAT and c-order.
* Create a ndarray of the given shape with values from N(0,1).
* Defaults to FLOAT and c-order.
*
* @param shape shape for the new array.
* @return new array with random values
*/
public static INDArray randn(long... shape) {
public static INDArray randn(@NonNull long... shape) {
INDArray ret = Nd4j.createUninitialized(shape, order());
return randn(ret);
}
@ -3369,7 +3424,7 @@ public class Nd4j {
* @param shape the shape of the array
* @return new array with random values
*/
public static INDArray randn(char order, int[] shape) {
public static INDArray randn(char order, @NonNull int... shape) {
INDArray ret = Nd4j.createUninitialized(shape, order);
return randn(ret);
}
@ -3381,34 +3436,45 @@ public class Nd4j {
* @param shape the shape of the array
* @return new array with random values
*/
public static INDArray randn(char order, long[] shape) {
public static INDArray randn(char order, @NonNull long... shape) {
INDArray ret = Nd4j.createUninitialized(shape, order);
return randn(ret);
}
/**
* @see #rand(DataType, char, int[])
* Random normal N(0,1) with the specified shape and array order
*
* @param order order of the output ndarray
* @param shape the shape of the ndarray
* @param dataType the data type of the ndarray
*/
public static INDArray randn(DataType dataType, char order, long[] shape) {
public static INDArray randn(@NonNull DataType dataType, char order, @NonNull long... shape) {
INDArray ret = Nd4j.createUninitialized(dataType, shape, order);
return randn(ret);
}
/**
* Random normal using the specified seed
* @deprecated use {@link Nd4j#randn(long, long...)}
*/
@Deprecated
public static INDArray randn(int[] shape, long seed) {
return randn(seed, ArrayUtil.toLongArray(shape));
}
/**
* Random normal N(0, 1) using the specified seed
*
* @param shape the shape of the array
* @return
*/
public static INDArray randn(int[] shape, long seed) {
public static INDArray randn(long seed, @NonNull long... shape) {
INDArray ret = Nd4j.createUninitialized(shape, order());
return randn(ret, seed);
}
/**
* Random normal using the current time stamp
* as the seed
* Random normal N(0, 1)
*
* @param rows the number of rows in the matrix
* @param columns the number of columns in the matrix
@ -3451,21 +3517,25 @@ public class Nd4j {
* @param r the random generator to use
* @return
*/
public static INDArray randn(long rows, long columns, org.nd4j.linalg.api.rng.Random r) {
public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) {
INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
return randn(ret, r);
}
/**
* Random normal using the given rng
*
* @param shape the shape of the array
* @param r the random generator to use
* @return
* @deprecated use {@link Nd4j#randn(org.nd4j.linalg.api.rng.Random, long...)}
*/
public static INDArray randn(int[] shape, org.nd4j.linalg.api.rng.Random r) {
final INDArray ret = Nd4j.createUninitialized(shape, order());
return randn(ret, r);
@Deprecated
public static INDArray randn(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random r) {
return randn(r, ArrayUtil.toLongArray(shape));
}
/**
* @deprecated use {@link Nd4j#randn(org.nd4j.linalg.api.rng.Random, long...)}
*/
@Deprecated
public static INDArray randn(long[] shape, @NonNull org.nd4j.linalg.api.rng.Random r) {
return randn(r, shape);
}
/**
@ -3475,7 +3545,7 @@ public class Nd4j {
* @param r the random generator to use
* @return
*/
public static INDArray randn(long[] shape, org.nd4j.linalg.api.rng.Random r) {
public static INDArray randn(@NonNull org.nd4j.linalg.api.rng.Random r, @NonNull long... shape) {
final INDArray ret = Nd4j.createUninitialized(shape, order());
return randn(ret, r);
}
@ -3509,7 +3579,7 @@ public class Nd4j {
* @param rng the random generator to use
* @return the given target array
*/
public static INDArray rand(INDArray target, org.nd4j.linalg.api.rng.Random rng) {
public static INDArray rand(INDArray target, @NonNull org.nd4j.linalg.api.rng.Random rng) {
return getExecutioner().exec(new UniformDistribution(target), rng);
}
@ -3520,7 +3590,7 @@ public class Nd4j {
* @param dist distribution to use
* @return the random ndarray with the specified shape
*/
public static INDArray rand(INDArray target, Distribution dist) {
public static INDArray rand(INDArray target, @NonNull Distribution dist) {
return dist.sample(target);
}
@ -3533,7 +3603,7 @@ public class Nd4j {
* @param rng the random generator to use
* @return the given target array
*/
public static INDArray rand(INDArray target, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
public static INDArray rand(INDArray target, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
if (min > max)
throw new IllegalArgumentException("the maximum value supplied is smaller than the minimum");
return getExecutioner().exec(new UniformDistribution(target, min, max), rng);
@ -3557,7 +3627,7 @@ public class Nd4j {
* @param rng the random generator to use
* @return the given target array
*/
public static INDArray randn(INDArray target, org.nd4j.linalg.api.rng.Random rng) {
public static INDArray randn(INDArray target, @NonNull org.nd4j.linalg.api.rng.Random rng) {
return getExecutioner().exec(new GaussianDistribution(target), rng);
}
@ -3569,7 +3639,7 @@ public class Nd4j {
* @param shape Shape of the result array
* @return Result array
*/
public static INDArray randomBernoulli(double p, long... shape) {
public static INDArray randomBernoulli(double p, @NonNull long... shape) {
return randomBernoulli(p, Nd4j.createUninitialized(shape));
}
@ -3595,7 +3665,7 @@ public class Nd4j {
* @param shape Shape of the result array
* @return Result array
*/
public static INDArray randomBinomial(int nTrials, double p, long... shape) {
public static INDArray randomBinomial(int nTrials, double p, @NonNull long... shape) {
return randomBinomial(nTrials, p, Nd4j.createUninitialized(shape));
}
@ -3757,7 +3827,7 @@ public class Nd4j {
* @param shape desired shape of new array.
* @return the created ndarray.
*/
public static INDArray create(boolean[][] data, long[] shape) {
public static INDArray create(boolean[][] data, @NonNull long... shape) {
return INSTANCE.create(ArrayUtil.flatten(data), shape, getStrides(shape), DataType.BOOL, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@ -4083,7 +4153,7 @@ public class Nd4j {
* @param shape the shape of the array
* @return the created ndarray
*/
public static INDArray create(float[] data, int[] shape) {
public static INDArray create(float[] data, int... shape) {
if (shape.length == 0 && data.length == 1) {
return scalar(data[0]);
}
@ -4102,7 +4172,7 @@ public class Nd4j {
/**
* @see #create(float[], int[])
*/
public static INDArray create(float[] data, long[] shape) {
public static INDArray create(float[] data, long... shape) {
if (shape.length == 0 && data.length == 1) {
return scalar(data[0]);
}
@ -4121,7 +4191,7 @@ public class Nd4j {
/**
* @see #create(float[], int[])
*/
public static INDArray create(double[] data, long[] shape) {
public static INDArray create(double[] data, long... shape) {
if (shape.length == 0 && data.length == 1) {
return scalar(data[0]);
}
@ -4144,7 +4214,7 @@ public class Nd4j {
* @param shape the shape of the array
* @return the created ndarray
*/
public static INDArray create(double[] data, int[] shape) {
public static INDArray create(double[] data, int... shape) {
if (shape.length == 1) {
if (shape[0] != data.length)
throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length);
@ -4346,7 +4416,7 @@ public class Nd4j {
* @param shape desired shape of new array. Must match the resulting shape of combining the list.
* @return the instance
*/
public static INDArray create(List<INDArray> list, int[] shape) {
public static INDArray create(List<INDArray> list, int... shape) {
checkShapeValues(shape);
INDArray ret = INSTANCE.create(list, shape);
@ -4356,7 +4426,7 @@ public class Nd4j {
/**
* @see #create(List, int[])
*/
public static INDArray create(List<INDArray> list, long[] shape) {
public static INDArray create(List<INDArray> list, long... shape) {
checkShapeValues(shape);
INDArray ret = INSTANCE.create(list, shape);
@ -4641,7 +4711,7 @@ public class Nd4j {
* @param shape desired shape of new array.
* @return the created ndarray.
*/
public static INDArray create(DataBuffer data, int[] shape) {
public static INDArray create(DataBuffer data, int... shape) {
checkShapeValues(shape);
return INSTANCE.create(data, shape);
}
@ -4649,7 +4719,7 @@ public class Nd4j {
/**
* @see #create(DataBuffer, int[])
*/
public static INDArray create(DataBuffer data, long[] shape) {
public static INDArray create(DataBuffer data, long... shape) {
checkShapeValues(shape);
return INSTANCE.create(data, shape);
}
@ -5049,7 +5119,7 @@ public class Nd4j {
*
* @param shape
*/
public static void checkShapeValues(long[] shape) {
public static void checkShapeValues(long... shape) {
for (long e: shape) {
if (e < 0)
throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape)
@ -5062,7 +5132,7 @@ public class Nd4j {
*
* @param shape
*/
public static void checkShapeValues(int[] shape) {
public static void checkShapeValues(int... shape) {
for (int e: shape) {
if (e < 1)
throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape)
@ -5070,7 +5140,7 @@ public class Nd4j {
}
}
protected static void checkShapeValues(int length, int[] shape) {
protected static void checkShapeValues(int length, int... shape) {
checkShapeValues(shape);
if (ArrayUtil.prodLong(shape) > length)
@ -5078,7 +5148,7 @@ public class Nd4j {
+ " doesn't match data length: " + length);
}
protected static void checkShapeValues(int length, long[] shape) {
protected static void checkShapeValues(int length, long... shape) {
checkShapeValues(shape);
if (ArrayUtil.prodLong(shape) > length)
@ -5171,7 +5241,7 @@ public class Nd4j {
/**
* @see #createUninitialized(long[])
*/
public static INDArray createUninitialized(int[] shape) {
public static INDArray createUninitialized(int... shape) {
if(shape.length == 0)
return Nd4j.scalar(dataType(), 0.0);
checkShapeValues(shape);
@ -5186,7 +5256,7 @@ public class Nd4j {
* @param shape the shape of the array
* @return the instance
*/
public static INDArray createUninitialized(long[] shape) {
public static INDArray createUninitialized(long... shape) {
checkShapeValues(shape);
//ensure shapes that wind up being scalar end up with the write shape
return createUninitialized(shape, Nd4j.order());
@ -5199,7 +5269,7 @@ public class Nd4j {
* @param shape
* @return
*/
public static INDArray createUninitializedDetached(int[] shape) {
public static INDArray createUninitializedDetached(int... shape) {
return createUninitializedDetached(shape, Nd4j.order());
}
@ -5210,7 +5280,7 @@ public class Nd4j {
* @param shape
* @return
*/
public static INDArray createUninitializedDetached(long[] shape) {
public static INDArray createUninitializedDetached(long... shape) {
return createUninitializedDetached(shape, Nd4j.order());
}
@ -5448,7 +5518,7 @@ public class Nd4j {
* @param shape the shape of the array
* @return the created array.
*/
public static INDArray zeros(DataType dataType, long... shape) {
public static INDArray zeros(DataType dataType, @NonNull long... shape) {
return INSTANCE.create(dataType, shape, 'c', Nd4j.getMemoryManager().getCurrentWorkspace());
}
@ -5621,7 +5691,7 @@ public class Nd4j {
* @param shape Shape fo the array
* @return the created ndarray
*/
public static INDArray ones(DataType dataType, long... shape) {
public static INDArray ones(DataType dataType, @NonNull long... shape) {
if(shape.length == 0)
return Nd4j.scalar(dataType, 1.0);
INDArray ret = INSTANCE.createUninitialized(dataType, shape, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace());
@ -5635,8 +5705,9 @@ public class Nd4j {
*
* @param arrs the first matrix to concat
*/
public static INDArray hstack(INDArray... arrs) {
return INSTANCE.hstack(arrs);
public static INDArray hstack(@NonNull INDArray... arrs) {
INDArray ret = INSTANCE.hstack(arrs);
return ret;
}
/**
@ -5656,7 +5727,7 @@ public class Nd4j {
*
* @param arrs Arrays to vstack
*/
public static INDArray vstack(INDArray... arrs) {
public static INDArray vstack(@NonNull INDArray... arrs) {
Preconditions.checkState(arrs != null && arrs.length > 0, "No input specified to vstack (null or length 0)");
if(arrs[0].rank() == 1){
//Edge case: vstack rank 1 arrays - gives rank 2... vstack([3],[3]) -> [2,3]
@ -5758,7 +5829,7 @@ public class Nd4j {
* @param arrays
* @return
*/
public static INDArray accumulate(INDArray... arrays) {
public static INDArray accumulate(@NonNull INDArray... arrays) {
if (arrays == null|| arrays.length == 0)
throw new ND4JIllegalStateException("Input for accumulation is null or empty");
@ -5798,7 +5869,7 @@ public class Nd4j {
* @param indexes indexes from source array
* @return
*/
public static INDArray pullRows(INDArray source, int sourceDimension, int[] indexes) {
public static INDArray pullRows(INDArray source, int sourceDimension, @NonNull int... indexes) {
return pullRows(source, sourceDimension, indexes, Nd4j.order());
}
@ -5847,7 +5918,7 @@ public class Nd4j {
* @param indexes indexes from source array
* @return Destination array with specified tensors
*/
public static INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, int[] indexes){
public static INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, @NonNull int... indexes){
if (sourceDimension >= source.rank())
throw new IllegalStateException("Source dimension can't be higher the rank of source tensor");
@ -5880,7 +5951,7 @@ public class Nd4j {
* @return Output array
* @see #concat(int, INDArray...)
*/
public static INDArray stack(int axis, INDArray... values){
public static INDArray stack(int axis, @NonNull INDArray... values){
Preconditions.checkArgument(values != null && values.length > 0, "No inputs: %s", values);
Preconditions.checkState(axis >= -(values[0].rank()+1) && axis < values[0].rank()+1, "Invalid axis: must be between " +
"%s (inclusive) and %s (exclusive) for rank %s input, got %s", -(values[0].rank()+1), values[0].rank()+1,
@ -5902,7 +5973,7 @@ public class Nd4j {
* the ndarray shapes save the dimension shape specified
* which is then the sum of the sizes along that dimension
*/
public static INDArray concat(int dimension, INDArray... toConcat) {
public static INDArray concat(int dimension, @NonNull INDArray... toConcat) {
if(dimension < 0) {
dimension += toConcat[0].rank();
}
@ -5919,8 +5990,9 @@ public class Nd4j {
* @param toConcat
* @return
*/
public static INDArray specialConcat(int dimension, INDArray... toConcat) {
return INSTANCE.specialConcat(dimension, toConcat);
public static INDArray specialConcat(int dimension, @NonNull INDArray... toConcat) {
INDArray ret = INSTANCE.specialConcat(dimension, toConcat);
return ret;
}
/**
@ -5950,7 +6022,7 @@ public class Nd4j {
* @param shape the shape of the array
* @return an ndarray with ones filled in
*/
public static INDArray zeros(int... shape) {
public static INDArray zeros(@NonNull int... shape) {
return Nd4j.create(shape);
}
@ -5961,7 +6033,7 @@ public class Nd4j {
* @param shape the shape of the array
* @return an ndarray with ones filled in
*/
public static INDArray zeros(long... shape) {
public static INDArray zeros(@NonNull long... shape) {
return Nd4j.create(shape);
}
@ -6085,7 +6157,7 @@ public class Nd4j {
* @return the strides for the given shape
* and order specified by NDArrays.order()
*/
public static int[] getStrides(int[] shape) {
public static int[] getStrides(@NonNull int... shape) {
return getStrides(shape, Nd4j.order());
}
@ -6097,7 +6169,7 @@ public class Nd4j {
* @return the strides for the given shape
* and order specified by NDArrays.order()
*/
public static long[] getStrides(long[] shape) {
public static long[] getStrides(@NonNull long... shape) {
return getStrides(shape, Nd4j.order());
}
@ -6108,7 +6180,7 @@ public class Nd4j {
* @param repeat the shape to repeat
* @return the tiled ndarray
*/
public static INDArray tile(INDArray tile, int... repeat) {
public static INDArray tile(INDArray tile, @NonNull int... repeat) {
int d = repeat.length;
long[] shape = ArrayUtil.copy(tile.shape());
long n = Math.max(tile.length(), 1);
@ -6157,11 +6229,11 @@ public class Nd4j {
* @return the strides for the given shape
* and order specified by NDArrays.order()
*/
public static int[] getComplexStrides(int[] shape) {
public static int[] getComplexStrides(@NonNull int... shape) {
return getComplexStrides(shape, Nd4j.order());
}
public static long[] getComplexStrides(long[] shape) {
public static long[] getComplexStrides(@NonNull long... shape) {
return getComplexStrides(shape, Nd4j.order());
}
@ -6501,7 +6573,7 @@ public class Nd4j {
*
* @return
*/
public static INDArray pile(INDArray... arrays) {
public static INDArray pile(@NonNull INDArray... arrays) {
// if we have vectors as input, it's just vstack use case
long[] shape = arrays[0].shape();
@ -6536,7 +6608,7 @@ public class Nd4j {
* @param dimensions
* @return
*/
public static INDArray[] tear(INDArray tensor, int... dimensions) {
public static INDArray[] tear(INDArray tensor, @NonNull int... dimensions) {
if (dimensions.length >= tensor.rank())
throw new ND4JIllegalStateException("Target dimensions number should be less tensor rank");

View File

@ -40,6 +40,8 @@ import java.util.Locale;
* @author Adam Gibson
* @author Susan Eraly
*/
@Getter
@Setter
public class NDArrayStrings {
public static final String EMPTY_ARRAY_STR = "[]";
@ -55,7 +57,7 @@ public class NDArrayStrings {
@Setter @Getter
private static long maxPrintElements = DEFAULT_MAX_PRINT_ELEMENTS;
private long localMaxPrintElements = maxPrintElements;
private String colSep = ",";
private String newLineSep = ",";
private int padding = 7;
@ -82,6 +84,38 @@ public class NDArrayStrings {
this(",", precision);
}
public NDArrayStrings(long maxElements, int precision) {
this(",", precision);
this.localMaxPrintElements = maxElements;
}
public NDArrayStrings(long maxElements) {
this();
this.localMaxPrintElements = maxElements;
}
public NDArrayStrings(long maxElements, boolean forceSummarize, int precision) {
this(",", precision);
if(forceSummarize)
localMaxPrintElements = 0;
else
localMaxPrintElements = maxElements;
}
public NDArrayStrings(boolean forceSummarize, int precision) {
this(",", precision);
if(forceSummarize)
localMaxPrintElements = 0;
}
public NDArrayStrings(boolean forceSummarize) {
this(",", 4);
if(forceSummarize)
localMaxPrintElements = 0;
}
/**
* Specify a delimiter for elements in columns for 2d arrays (or in the rank-1th dimension in higher order arrays)
@ -93,13 +127,17 @@ public class NDArrayStrings {
public NDArrayStrings(String colSep, int precision) {
this.colSep = colSep;
if (!colSep.replaceAll("\\s", "").equals(",")) this.newLineSep = "";
this.precision = precision;
String decFormatNum = "0.";
while (precision > 0) {
decFormatNum += "0";
precision -= 1;
StringBuilder decFormatNum = new StringBuilder("0.");
int prec = Math.abs(precision);
this.precision = prec;
boolean useHash = precision < 0;
while (prec > 0) {
decFormatNum.append(useHash ? "#" : "0");
prec -= 1;
}
this.decimalFormat = localeIndifferentDecimalFormat(decFormatNum);
this.decimalFormat = localeIndifferentDecimalFormat(decFormatNum.toString());
}
/**
@ -147,7 +185,7 @@ public class NDArrayStrings {
if (this.scientificFormat.length() + 2 > this.padding) this.padding = this.scientificFormat.length() + 2;
this.maxToPrintWithoutSwitching = Math.pow(10,this.precision);
this.minToPrintWithoutSwitching = 1.0/(this.maxToPrintWithoutSwitching);
return format(arr, 0, summarize && arr.length() > maxPrintElements);
return format(arr, 0, summarize && arr.length() > localMaxPrintElements);
}
private String format(INDArray arr, int offset, boolean summarize) {

View File

@ -0,0 +1,65 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg;
import static org.junit.Assert.assertEquals;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@RunWith(Parameterized.class)
@Slf4j
public class ToStringTest extends BaseNd4jTest {
public ToStringTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void testToString() throws Exception {
assertEquals("[ 1, 2, 3]",
Nd4j.createFromArray(1, 2, 3).toString());
assertEquals("[ 1, 2, 3 ... 6 7, 8]",
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(1000, false, 2));
assertEquals("[ 1.132, 2.644, 3.234]",
Nd4j.createFromArray(1.132414, 2.64356456, 3.234234).toString(1000, false, 3));
assertEquals("[ 1.132414, 2.64356456, 3.25345234]",
Nd4j.createFromArray(1.132414, 2.64356456, 3.25345234).toStringFull());
assertEquals("[ 1, 2, 3 ... 6 7, 8]",
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(100, true, 1));
}
@Override
public char ordering() {
return 'c';
}
}

View File

@ -124,6 +124,36 @@ public enum DataType {
}
}
/**
* @return the max number of significant decimal digits
*/
public int precision(){
switch (this){
case DOUBLE:
return 17;
case FLOAT:
return 9;
case HALF:
return 5;
case BFLOAT16:
return 4;
case LONG:
case INT:
case SHORT:
case BYTE:
case UBYTE:
case BOOL:
case UTF8:
case COMPRESSED:
case UINT16:
case UINT32:
case UINT64:
case UNKNOWN:
default:
return -1;
}
}
/**
* @return For fixed-width types, this returns the number of bytes per array element
*/