Nd4j: Change array args to vararg and add toString options (#36)
* changed [] to ... Signed-off-by: Ryan Nett <rnett@skymind.io> * added randn(long seed, int... shape) Signed-off-by: Ryan Nett <rnett@skymind.io> * Fixed a couple of methods Signed-off-by: Ryan Nett <rnett@skymind.io> * ToString methods w/ options Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, less toString methods, and a few ops I missed Signed-off-by: Ryan Nett <rnett@skymind.io> * some javadocs, change int... to long... where possible Signed-off-by: Ryan Nett <rnett@skymind.io> * another javadoc Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc fix Signed-off-by: Ryan Nett <rnett@skymind.io> * just javadoc in INDArray Signed-off-by: Ryan Nett <rnett@skymind.io> * local/static fix Signed-off-by: Ryan Nett <rnett@skymind.io> * Add @NonNull to options Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc updates/fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more @NonNulls Signed-off-by: Ryan Nett <rnett@skymind.io> * even more @NonNulls, this time on varargs Signed-off-by: Ryan Nett <rnett@skymind.io>master
parent
366d850f5e
commit
ef8cb55b7b
|
@ -38,7 +38,6 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
|
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
|
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
|
||||||
|
@ -5995,13 +5994,30 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
|
return toString(new NDArrayStrings());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(@NonNull NDArrayStrings options){
|
||||||
if (!isCompressed() && !preventUnpack)
|
if (!isCompressed() && !preventUnpack)
|
||||||
return new NDArrayStrings().format(this);
|
return options.format(this);
|
||||||
else if (isCompressed() && compressDebug)
|
else if (isCompressed() && compressDebug)
|
||||||
return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. This is to prevent auto decompression from being triggered.";
|
return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. This is to prevent auto decompression from being triggered.";
|
||||||
else if (preventUnpack)
|
else if (preventUnpack)
|
||||||
return "Array string unpacking is disabled.";
|
return "Array string unpacking is disabled.";
|
||||||
return new NDArrayStrings().format(this);
|
return options.format(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(long maxElements, boolean forceSummarize, int precision){
|
||||||
|
return toString(new NDArrayStrings(maxElements, forceSummarize, precision));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toStringFull(){
|
||||||
|
return toString(Long.MAX_VALUE, false, -1 * dataType().precision());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -39,6 +39,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||||
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
||||||
import org.nd4j.linalg.indexing.conditions.Condition;
|
import org.nd4j.linalg.indexing.conditions.Condition;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
import org.nd4j.linalg.string.NDArrayStrings;
|
||||||
import org.nd4j.linalg.util.LinAlgExceptions;
|
import org.nd4j.linalg.util.LinAlgExceptions;
|
||||||
|
|
||||||
import java.nio.LongBuffer;
|
import java.nio.LongBuffer;
|
||||||
|
@ -47,7 +48,9 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.NoSuchElementException;
|
import java.util.NoSuchElementException;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.Nd4j.compressDebug;
|
||||||
import static org.nd4j.linalg.factory.Nd4j.createUninitialized;
|
import static org.nd4j.linalg.factory.Nd4j.createUninitialized;
|
||||||
|
import static org.nd4j.linalg.factory.Nd4j.preventUnpack;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author Audrey Loeffel
|
* @author Audrey Loeffel
|
||||||
|
@ -2016,4 +2019,24 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
||||||
public INDArray ulike() {
|
public INDArray ulike() {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
throw new UnsupportedOperationException("Not yet implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(NDArrayStrings options){
|
||||||
|
return "SPARSE ARRAY, TOSTRING NOT SUPPORTED";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(long maxElements, boolean forceSummarize, int decimalPlaces){
|
||||||
|
return toString(new NDArrayStrings(maxElements, forceSummarize, decimalPlaces));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toStringFull(){
|
||||||
|
return toString(Long.MAX_VALUE, false, dataType().precision());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return toString(new NDArrayStrings());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,11 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ndarray;
|
package org.nd4j.linalg.api.ndarray;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.Nd4j.compressDebug;
|
||||||
|
import static org.nd4j.linalg.factory.Nd4j.preventUnpack;
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -28,6 +32,7 @@ import org.nd4j.linalg.indexing.conditions.Condition;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.nio.LongBuffer;
|
import java.nio.LongBuffer;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import org.nd4j.linalg.string.NDArrayStrings;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Interface for an ndarray
|
* Interface for an ndarray
|
||||||
|
@ -1914,7 +1919,7 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
* @param shape
|
* @param shape
|
||||||
* @param stride
|
* @param stride
|
||||||
*/
|
*/
|
||||||
public void setShapeAndStride(int[] shape, int[] stride);
|
void setShapeAndStride(int[] shape, int[] stride);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the ordering
|
* Set the ordering
|
||||||
|
@ -2843,4 +2848,26 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
//INDArray[] gains(INDArray input, INDArray gradx, INDArray epsilon);
|
//INDArray[] gains(INDArray input, INDArray gradx, INDArray epsilon);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a string representation of the array with configurable formatting
|
||||||
|
* @param options format options
|
||||||
|
*/
|
||||||
|
String toString(@NonNull NDArrayStrings options);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a string representation of the array
|
||||||
|
*
|
||||||
|
* @param maxElements Summarize if more than maxElements in the array
|
||||||
|
* @param forceSummarize Force a summary instead of a full print
|
||||||
|
* @param precision The number of decimals to print. Doesn't print trailing 0s if negative
|
||||||
|
*/
|
||||||
|
String toString(long maxElements, boolean forceSummarize, int precision);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ToString with unlimited elements and precision
|
||||||
|
* @see org.nd4j.linalg.api.ndarray.BaseNDArray#toString(long, boolean, int)
|
||||||
|
*/
|
||||||
|
String toStringFull();
|
||||||
}
|
}
|
||||||
|
|
|
@ -407,7 +407,7 @@ public class Nd4j {
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static void shuffle(INDArray toShuffle, Random random, int... dimension) {
|
public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) {
|
||||||
INSTANCE.shuffle(toShuffle, random, dimension);
|
INSTANCE.shuffle(toShuffle, random, dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -418,7 +418,7 @@ public class Nd4j {
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static void shuffle(INDArray toShuffle, int... dimension) {
|
public static void shuffle(INDArray toShuffle, @NonNull int... dimension) {
|
||||||
//shuffle(toShuffle, new Random(), dimension);
|
//shuffle(toShuffle, new Random(), dimension);
|
||||||
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
||||||
}
|
}
|
||||||
|
@ -430,7 +430,7 @@ public class Nd4j {
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static void shuffle(Collection<INDArray> toShuffle, int... dimension) {
|
public static void shuffle(Collection<INDArray> toShuffle, @NonNull int... dimension) {
|
||||||
//shuffle(toShuffle, new Random(), dimension);
|
//shuffle(toShuffle, new Random(), dimension);
|
||||||
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
||||||
}
|
}
|
||||||
|
@ -442,7 +442,7 @@ public class Nd4j {
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static void shuffle(Collection<INDArray> toShuffle, Random rnd, int... dimension) {
|
public static void shuffle(Collection<INDArray> toShuffle, Random rnd, @NonNull int... dimension) {
|
||||||
//shuffle(toShuffle, new Random(), dimension);
|
//shuffle(toShuffle, new Random(), dimension);
|
||||||
INSTANCE.shuffle(toShuffle, rnd, dimension);
|
INSTANCE.shuffle(toShuffle, rnd, dimension);
|
||||||
}
|
}
|
||||||
|
@ -659,7 +659,7 @@ public class Nd4j {
|
||||||
* @param dimension the dimension along which to get the maximum
|
* @param dimension the dimension along which to get the maximum
|
||||||
* @return array of maximum values.
|
* @return array of maximum values.
|
||||||
*/
|
*/
|
||||||
public static INDArray argMax(INDArray arr, int... dimension) {
|
public static INDArray argMax(INDArray arr, @NonNull int... dimension) {
|
||||||
IMax imax = new IMax(arr, dimension);
|
IMax imax = new IMax(arr, dimension);
|
||||||
return Nd4j.getExecutioner().exec(imax);
|
return Nd4j.getExecutioner().exec(imax);
|
||||||
}
|
}
|
||||||
|
@ -667,7 +667,7 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* @see #argMax(INDArray, int...)
|
* @see #argMax(INDArray, int...)
|
||||||
*/
|
*/
|
||||||
public static INDArray argMin(INDArray arr, int... dimension) {
|
public static INDArray argMin(INDArray arr, @NonNull int... dimension) {
|
||||||
IMin imin = new IMin(arr, dimension);
|
IMin imin = new IMin(arr, dimension);
|
||||||
return Nd4j.getExecutioner().exec(imin);
|
return Nd4j.getExecutioner().exec(imin);
|
||||||
}
|
}
|
||||||
|
@ -2325,7 +2325,7 @@ public class Nd4j {
|
||||||
* @return the flattened representation of
|
* @return the flattened representation of
|
||||||
* these ndarrays
|
* these ndarrays
|
||||||
*/
|
*/
|
||||||
public static INDArray toFlattened(INDArray... matrices) {
|
public static INDArray toFlattened(@NonNull INDArray... matrices) {
|
||||||
return INSTANCE.toFlattened(matrices);
|
return INSTANCE.toFlattened(matrices);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2337,7 +2337,7 @@ public class Nd4j {
|
||||||
* @return the flattened representation of
|
* @return the flattened representation of
|
||||||
* these ndarrays
|
* these ndarrays
|
||||||
*/
|
*/
|
||||||
public static INDArray toFlattened(char order, INDArray... matrices) {
|
public static INDArray toFlattened(char order, @NonNull INDArray... matrices) {
|
||||||
return INSTANCE.toFlattened(order, matrices);
|
return INSTANCE.toFlattened(order, matrices);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3071,7 +3071,7 @@ public class Nd4j {
|
||||||
return choice(source, probs, numSamples, Nd4j.getRandom());
|
return choice(source, probs, numSamples, Nd4j.getRandom());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray appendBias(INDArray... vectors) {
|
public static INDArray appendBias(@NonNull INDArray... vectors) {
|
||||||
INDArray ret = INSTANCE.appendBias(vectors);
|
INDArray ret = INSTANCE.appendBias(vectors);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -3092,13 +3092,12 @@ public class Nd4j {
|
||||||
////////////////////// RANDOM ///////////////////////////////
|
////////////////////// RANDOM ///////////////////////////////
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given shape using
|
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape
|
||||||
* the current time as the seed
|
|
||||||
*
|
*
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(int[] shape) {
|
public static INDArray rand(@NonNull int... shape) {
|
||||||
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
@ -3106,18 +3105,20 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* @see #rand(int[])
|
* @see #rand(int[])
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(long[] shape) {
|
public static INDArray rand(@NonNull long... shape) {
|
||||||
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with given type and shape.
|
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape and data type
|
||||||
* @param dataType datatype
|
*
|
||||||
* @param shape shape
|
* @param shape the shape of the ndarray
|
||||||
* @return new array.
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(DataType dataType, long... shape) {
|
public static INDArray rand(@NonNull DataType dataType, @NonNull long... shape) {
|
||||||
|
Preconditions.checkArgument(dataType.isFPType(),
|
||||||
|
"Can't create a random array of a non-floating point data type");
|
||||||
INDArray ret = createUninitialized(dataType, shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
INDArray ret = createUninitialized(dataType, shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
@ -3125,52 +3126,65 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given shape and array order
|
* Create a random ndarray with the given shape and array order
|
||||||
*
|
*
|
||||||
|
* Values are sampled from a uniform distribution over (0, 1)
|
||||||
|
*
|
||||||
* @param order the order of the ndarray to return
|
* @param order the order of the ndarray to return
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(char order, int[] shape) {
|
public static INDArray rand(char order, @NonNull int... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(shape, order); //INSTANCE.rand(order, shape);
|
INDArray ret = Nd4j.createUninitialized(shape, order); //INSTANCE.rand(order, shape);
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given datatype, order and shape.
|
* @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...))
|
||||||
*
|
|
||||||
* The datatype must be one of the floating point types.
|
|
||||||
*
|
|
||||||
* @param dataType datatype
|
|
||||||
* @param order order
|
|
||||||
* @param shape shape
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(DataType dataType, char order, int[] shape) {
|
@Deprecated
|
||||||
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), order); //INSTANCE.rand(order, shape);
|
public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) {
|
||||||
return rand(ret);
|
return rand(dataType, order, ArrayUtil.toLongArray(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @see #rand(DataType, char, int[])
|
* @deprecated use {@link org.nd4j.linalg.factory.Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)}
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(DataType dataType, int[] shape, char order) {
|
@Deprecated
|
||||||
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), order); //INSTANCE.rand(order, shape);
|
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull int... shape) {
|
||||||
return rand(ret);
|
return rand(dataType, order, ArrayUtil.toLongArray(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given datatype and shape.
|
* Create a random ndarray with the given shape, data type, and array order
|
||||||
* using the default Nd4j order.
|
|
||||||
*
|
*
|
||||||
* @see #rand(DataType, char, int[])
|
* Values are sampled from a uniform distribution over (0, 1)
|
||||||
|
*
|
||||||
|
* @param order the order of the ndarray to return
|
||||||
|
* @param shape the shape of the ndarray
|
||||||
|
* @param dataType the data type of the ndarray
|
||||||
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(DataType dataType, int[] shape) {
|
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) {
|
||||||
|
INDArray ret = Nd4j.createUninitialized(dataType, shape, order); //INSTANCE.rand(order, shape);
|
||||||
|
return rand(ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a random ndarray with the given shape and data type
|
||||||
|
*
|
||||||
|
* Values are sampled from a uniform distribution over (0, 1)
|
||||||
|
*
|
||||||
|
* @param shape the shape of the ndarray
|
||||||
|
* @param dataType the data type of the ndarray
|
||||||
|
* @return the random ndarray with the specified shape
|
||||||
|
*/
|
||||||
|
public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); //INSTANCE.rand(order, shape);
|
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); //INSTANCE.rand(order, shape);
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given shape using
|
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape
|
||||||
* the current time as the seed
|
|
||||||
*
|
*
|
||||||
* @param rows the number of rows in the matrix
|
* @param rows the number of rows in the matrix
|
||||||
* @param columns the number of columns in the matrix
|
* @param columns the number of columns in the matrix
|
||||||
|
@ -3187,6 +3201,8 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given shape and output order
|
* Create a random ndarray with the given shape and output order
|
||||||
*
|
*
|
||||||
|
* Values are sampled from a uniform distribution over (0, 1)
|
||||||
|
*
|
||||||
* @param rows the number of rows in the matrix
|
* @param rows the number of rows in the matrix
|
||||||
* @param columns the number of columns in the matrix
|
* @param columns the number of columns in the matrix
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
|
@ -3200,20 +3216,31 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given shape using given seed
|
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape
|
||||||
|
* using given seed
|
||||||
*
|
*
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @param seed the seed to use
|
* @param seed the seed to use
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(int[] shape, long seed) {
|
public static INDArray rand(long seed, @NonNull long... shape) {
|
||||||
INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed);
|
INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed);
|
||||||
return rand(ret, seed);
|
return rand(ret, seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated use {@link Nd4j#rand(long, long...)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public static INDArray rand(int[] shape, long seed) {
|
||||||
|
INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed);
|
||||||
|
return rand(seed, ArrayUtil.toLongArray(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given shape using the given seed
|
* Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape
|
||||||
|
* using the given seed
|
||||||
*
|
*
|
||||||
* @param rows the number of rows in the matrix
|
* @param rows the number of rows in the matrix
|
||||||
* @param columns the columns of the ndarray
|
* @param columns the columns of the ndarray
|
||||||
|
@ -3225,6 +3252,14 @@ public class Nd4j {
|
||||||
return rand(ret, seed);
|
return rand(ret, seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.rng.Random, long...)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public static INDArray rand(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||||
|
return rand(rng, ArrayUtil.toLongArray(shape));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given shape using the given RandomGenerator
|
* Create a random ndarray with the given shape using the given RandomGenerator
|
||||||
*
|
*
|
||||||
|
@ -3232,22 +3267,26 @@ public class Nd4j {
|
||||||
* @param rng the random generator to use
|
* @param rng the random generator to use
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(int[] shape, org.nd4j.linalg.api.rng.Random rng) {
|
public static INDArray rand(@NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) {
|
||||||
INDArray ret = createUninitialized(shape, Nd4j.order()); //INSTANCE.rand(shape, rng);
|
INDArray ret = createUninitialized(shape, Nd4j.order()); //INSTANCE.rand(shape, rng);
|
||||||
return rand(ret, rng);
|
return rand(ret, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a random ndarray with the given shape using the given rng
|
* @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.rng.distribution.Distribution, long...)}
|
||||||
*
|
|
||||||
* @param shape the shape of the array
|
|
||||||
* @param dist distribution to use
|
|
||||||
* @return the random ndarray with the specified shape
|
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(int[] shape, Distribution dist) {
|
@Deprecated
|
||||||
//INDArray ret = INSTANCE.rand(shape, dist);
|
public static INDArray rand(int[] shape, @NonNull Distribution dist) {
|
||||||
//logCreationIfNecessary(ret);
|
return rand(dist, ArrayUtil.toLongArray(shape));
|
||||||
return dist.sample(shape);
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated use
|
||||||
|
* {@link org.nd4j.linalg.factory.Nd4j#rand(org.nd4j.linalg.api.rng.distribution.Distribution, long...)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public static INDArray rand(long[] shape, @NonNull Distribution dist) {
|
||||||
|
return rand(dist, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3257,7 +3296,7 @@ public class Nd4j {
|
||||||
* @param dist distribution to use
|
* @param dist distribution to use
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(long[] shape, Distribution dist) {
|
public static INDArray rand(@NonNull Distribution dist, @NonNull long... shape) {
|
||||||
//INDArray ret = INSTANCE.rand(shape, dist);
|
//INDArray ret = INSTANCE.rand(shape, dist);
|
||||||
//logCreationIfNecessary(ret);
|
//logCreationIfNecessary(ret);
|
||||||
return dist.sample(shape);
|
return dist.sample(shape);
|
||||||
|
@ -3271,21 +3310,24 @@ public class Nd4j {
|
||||||
* @param rng the random generator to use
|
* @param rng the random generator to use
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(int rows, int columns, org.nd4j.linalg.api.rng.Random rng) {
|
public static INDArray rand(int rows, int columns, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||||
INDArray ret = createUninitialized(new int[] {rows, columns}, order());//INSTANCE.rand(rows, columns, rng);
|
INDArray ret = createUninitialized(new int[] {rows, columns}, order());//INSTANCE.rand(rows, columns, rng);
|
||||||
return rand(ret, rng);
|
return rand(ret, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generates a random matrix between min and max
|
* @deprecated use {@link Nd4j#rand(double, double, org.nd4j.linalg.api.rng.Random, long...)}
|
||||||
*
|
|
||||||
* @param shape the number of rows of the matrix
|
|
||||||
* @param min the minimum number
|
|
||||||
* @param max the maximum number
|
|
||||||
* @param rng the rng to use
|
|
||||||
* @return a random matrix of the specified shape and range
|
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(int[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
|
@Deprecated
|
||||||
|
public static INDArray rand(int[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||||
|
return rand(min, max, rng, ArrayUtil.toLongArray(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated use {@link Nd4j#rand(double, double, org.nd4j.linalg.api.rng.Random, long...)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public static INDArray rand(long[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||||
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng);
|
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng);
|
||||||
return rand(ret, min, max, rng);
|
return rand(ret, min, max, rng);
|
||||||
}
|
}
|
||||||
|
@ -3299,7 +3341,7 @@ public class Nd4j {
|
||||||
* @param rng the rng to use
|
* @param rng the rng to use
|
||||||
* @return a random matrix of the specified shape and range
|
* @return a random matrix of the specified shape and range
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(long[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
|
public static INDArray rand(double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) {
|
||||||
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng);
|
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng);
|
||||||
return rand(ret, min, max, rng);
|
return rand(ret, min, max, rng);
|
||||||
}
|
}
|
||||||
|
@ -3314,7 +3356,7 @@ public class Nd4j {
|
||||||
* @param rng the rng to use
|
* @param rng the rng to use
|
||||||
* @return a drandom matrix of the specified shape and range
|
* @return a drandom matrix of the specified shape and range
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(int rows, int columns, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
|
public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||||
INDArray ret = createUninitialized(rows, columns);//INSTANCE.rand(rows, columns, min, max, rng);
|
INDArray ret = createUninitialized(rows, columns);//INSTANCE.rand(rows, columns, min, max, rng);
|
||||||
return rand(ret, min, max, rng);
|
return rand(ret, min, max, rng);
|
||||||
}
|
}
|
||||||
|
@ -3330,34 +3372,47 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Random normal using the current time stamp
|
* Create a ndarray of the given shape with values from N(0,1)
|
||||||
* as the seed
|
|
||||||
*
|
*
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return new array with random values
|
* @return new array with random values
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(int[] shape) {
|
public static INDArray randn(@NonNull int... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(shape, order());
|
return randn(ArrayUtil.toLongArray(shape));
|
||||||
return randn(ret);
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a ndarray of the given shape and data type with values from N(0,1)
|
||||||
|
*
|
||||||
|
* @param shape the shape of the ndarray
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static INDArray randn(@NonNull DataType dataType, @NonNull int... shape) {
|
||||||
|
return randn(dataType, ArrayUtil.toLongArray(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Random normal ndarray of given datatype and shape,
|
* Create a ndarray of the given shape and data type with values from N(0,1)
|
||||||
|
*
|
||||||
* @param dataType datatype to use, must be a float type datatype.
|
* @param dataType datatype to use, must be a float type datatype.
|
||||||
* @param shape shape for the new array.
|
* @param shape shape for the new array.
|
||||||
* @return new array with random values
|
* @return new array with random values
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(DataType dataType, long... shape) {
|
public static INDArray randn(@NonNull DataType dataType, @NonNull long... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(dataType, shape, order());
|
INDArray ret = Nd4j.createUninitialized(dataType, shape, order());
|
||||||
return randn(ret);
|
return randn(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Random normal ndarray of given shape. defaults to FLOAT and c-order.
|
* Create a ndarray of the given shape with values from N(0,1).
|
||||||
|
* Defaults to FLOAT and c-order.
|
||||||
|
*
|
||||||
* @param shape shape for the new array.
|
* @param shape shape for the new array.
|
||||||
* @return new array with random values
|
* @return new array with random values
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(long... shape) {
|
public static INDArray randn(@NonNull long... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(shape, order());
|
INDArray ret = Nd4j.createUninitialized(shape, order());
|
||||||
return randn(ret);
|
return randn(ret);
|
||||||
}
|
}
|
||||||
|
@ -3369,7 +3424,7 @@ public class Nd4j {
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return new array with random values
|
* @return new array with random values
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(char order, int[] shape) {
|
public static INDArray randn(char order, @NonNull int... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(shape, order);
|
INDArray ret = Nd4j.createUninitialized(shape, order);
|
||||||
return randn(ret);
|
return randn(ret);
|
||||||
}
|
}
|
||||||
|
@ -3381,34 +3436,45 @@ public class Nd4j {
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return new array with random values
|
* @return new array with random values
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(char order, long[] shape) {
|
public static INDArray randn(char order, @NonNull long... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(shape, order);
|
INDArray ret = Nd4j.createUninitialized(shape, order);
|
||||||
return randn(ret);
|
return randn(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @see #rand(DataType, char, int[])
|
* Random normal N(0,1) with the specified shape and array order
|
||||||
|
*
|
||||||
|
* @param order order of the output ndarray
|
||||||
|
* @param shape the shape of the ndarray
|
||||||
|
* @param dataType the data type of the ndarray
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(DataType dataType, char order, long[] shape) {
|
public static INDArray randn(@NonNull DataType dataType, char order, @NonNull long... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(dataType, shape, order);
|
INDArray ret = Nd4j.createUninitialized(dataType, shape, order);
|
||||||
return randn(ret);
|
return randn(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Random normal using the specified seed
|
* @deprecated use {@link Nd4j#randn(long, long...)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public static INDArray randn(int[] shape, long seed) {
|
||||||
|
return randn(seed, ArrayUtil.toLongArray(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Random normal N(0, 1) using the specified seed
|
||||||
*
|
*
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(int[] shape, long seed) {
|
public static INDArray randn(long seed, @NonNull long... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(shape, order());
|
INDArray ret = Nd4j.createUninitialized(shape, order());
|
||||||
return randn(ret, seed);
|
return randn(ret, seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Random normal using the current time stamp
|
* Random normal N(0, 1)
|
||||||
* as the seed
|
|
||||||
*
|
*
|
||||||
* @param rows the number of rows in the matrix
|
* @param rows the number of rows in the matrix
|
||||||
* @param columns the number of columns in the matrix
|
* @param columns the number of columns in the matrix
|
||||||
|
@ -3451,21 +3517,25 @@ public class Nd4j {
|
||||||
* @param r the random generator to use
|
* @param r the random generator to use
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(long rows, long columns, org.nd4j.linalg.api.rng.Random r) {
|
public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) {
|
||||||
INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
|
INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
|
||||||
return randn(ret, r);
|
return randn(ret, r);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Random normal using the given rng
|
* @deprecated use {@link Nd4j#randn(org.nd4j.linalg.api.rng.Random, long...)}
|
||||||
*
|
|
||||||
* @param shape the shape of the array
|
|
||||||
* @param r the random generator to use
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(int[] shape, org.nd4j.linalg.api.rng.Random r) {
|
@Deprecated
|
||||||
final INDArray ret = Nd4j.createUninitialized(shape, order());
|
public static INDArray randn(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random r) {
|
||||||
return randn(ret, r);
|
return randn(r, ArrayUtil.toLongArray(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated use {@link Nd4j#randn(org.nd4j.linalg.api.rng.Random, long...)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public static INDArray randn(long[] shape, @NonNull org.nd4j.linalg.api.rng.Random r) {
|
||||||
|
return randn(r, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3475,7 +3545,7 @@ public class Nd4j {
|
||||||
* @param r the random generator to use
|
* @param r the random generator to use
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(long[] shape, org.nd4j.linalg.api.rng.Random r) {
|
public static INDArray randn(@NonNull org.nd4j.linalg.api.rng.Random r, @NonNull long... shape) {
|
||||||
final INDArray ret = Nd4j.createUninitialized(shape, order());
|
final INDArray ret = Nd4j.createUninitialized(shape, order());
|
||||||
return randn(ret, r);
|
return randn(ret, r);
|
||||||
}
|
}
|
||||||
|
@ -3509,7 +3579,7 @@ public class Nd4j {
|
||||||
* @param rng the random generator to use
|
* @param rng the random generator to use
|
||||||
* @return the given target array
|
* @return the given target array
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(INDArray target, org.nd4j.linalg.api.rng.Random rng) {
|
public static INDArray rand(INDArray target, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||||
return getExecutioner().exec(new UniformDistribution(target), rng);
|
return getExecutioner().exec(new UniformDistribution(target), rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3520,7 +3590,7 @@ public class Nd4j {
|
||||||
* @param dist distribution to use
|
* @param dist distribution to use
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(INDArray target, Distribution dist) {
|
public static INDArray rand(INDArray target, @NonNull Distribution dist) {
|
||||||
return dist.sample(target);
|
return dist.sample(target);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3533,7 +3603,7 @@ public class Nd4j {
|
||||||
* @param rng the random generator to use
|
* @param rng the random generator to use
|
||||||
* @return the given target array
|
* @return the given target array
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(INDArray target, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
|
public static INDArray rand(INDArray target, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||||
if (min > max)
|
if (min > max)
|
||||||
throw new IllegalArgumentException("the maximum value supplied is smaller than the minimum");
|
throw new IllegalArgumentException("the maximum value supplied is smaller than the minimum");
|
||||||
return getExecutioner().exec(new UniformDistribution(target, min, max), rng);
|
return getExecutioner().exec(new UniformDistribution(target, min, max), rng);
|
||||||
|
@ -3557,7 +3627,7 @@ public class Nd4j {
|
||||||
* @param rng the random generator to use
|
* @param rng the random generator to use
|
||||||
* @return the given target array
|
* @return the given target array
|
||||||
*/
|
*/
|
||||||
public static INDArray randn(INDArray target, org.nd4j.linalg.api.rng.Random rng) {
|
public static INDArray randn(INDArray target, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||||
return getExecutioner().exec(new GaussianDistribution(target), rng);
|
return getExecutioner().exec(new GaussianDistribution(target), rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3569,7 +3639,7 @@ public class Nd4j {
|
||||||
* @param shape Shape of the result array
|
* @param shape Shape of the result array
|
||||||
* @return Result array
|
* @return Result array
|
||||||
*/
|
*/
|
||||||
public static INDArray randomBernoulli(double p, long... shape) {
|
public static INDArray randomBernoulli(double p, @NonNull long... shape) {
|
||||||
return randomBernoulli(p, Nd4j.createUninitialized(shape));
|
return randomBernoulli(p, Nd4j.createUninitialized(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3595,7 +3665,7 @@ public class Nd4j {
|
||||||
* @param shape Shape of the result array
|
* @param shape Shape of the result array
|
||||||
* @return Result array
|
* @return Result array
|
||||||
*/
|
*/
|
||||||
public static INDArray randomBinomial(int nTrials, double p, long... shape) {
|
public static INDArray randomBinomial(int nTrials, double p, @NonNull long... shape) {
|
||||||
return randomBinomial(nTrials, p, Nd4j.createUninitialized(shape));
|
return randomBinomial(nTrials, p, Nd4j.createUninitialized(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3757,7 +3827,7 @@ public class Nd4j {
|
||||||
* @param shape desired shape of new array.
|
* @param shape desired shape of new array.
|
||||||
* @return the created ndarray.
|
* @return the created ndarray.
|
||||||
*/
|
*/
|
||||||
public static INDArray create(boolean[][] data, long[] shape) {
|
public static INDArray create(boolean[][] data, @NonNull long... shape) {
|
||||||
return INSTANCE.create(ArrayUtil.flatten(data), shape, getStrides(shape), DataType.BOOL, Nd4j.getMemoryManager().getCurrentWorkspace());
|
return INSTANCE.create(ArrayUtil.flatten(data), shape, getStrides(shape), DataType.BOOL, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4083,7 +4153,7 @@ public class Nd4j {
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return the created ndarray
|
* @return the created ndarray
|
||||||
*/
|
*/
|
||||||
public static INDArray create(float[] data, int[] shape) {
|
public static INDArray create(float[] data, int... shape) {
|
||||||
if (shape.length == 0 && data.length == 1) {
|
if (shape.length == 0 && data.length == 1) {
|
||||||
return scalar(data[0]);
|
return scalar(data[0]);
|
||||||
}
|
}
|
||||||
|
@ -4102,7 +4172,7 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* @see #create(float[], int[])
|
* @see #create(float[], int[])
|
||||||
*/
|
*/
|
||||||
public static INDArray create(float[] data, long[] shape) {
|
public static INDArray create(float[] data, long... shape) {
|
||||||
if (shape.length == 0 && data.length == 1) {
|
if (shape.length == 0 && data.length == 1) {
|
||||||
return scalar(data[0]);
|
return scalar(data[0]);
|
||||||
}
|
}
|
||||||
|
@ -4121,7 +4191,7 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* @see #create(float[], int[])
|
* @see #create(float[], int[])
|
||||||
*/
|
*/
|
||||||
public static INDArray create(double[] data, long[] shape) {
|
public static INDArray create(double[] data, long... shape) {
|
||||||
if (shape.length == 0 && data.length == 1) {
|
if (shape.length == 0 && data.length == 1) {
|
||||||
return scalar(data[0]);
|
return scalar(data[0]);
|
||||||
}
|
}
|
||||||
|
@ -4144,7 +4214,7 @@ public class Nd4j {
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return the created ndarray
|
* @return the created ndarray
|
||||||
*/
|
*/
|
||||||
public static INDArray create(double[] data, int[] shape) {
|
public static INDArray create(double[] data, int... shape) {
|
||||||
if (shape.length == 1) {
|
if (shape.length == 1) {
|
||||||
if (shape[0] != data.length)
|
if (shape[0] != data.length)
|
||||||
throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length);
|
throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length);
|
||||||
|
@ -4346,7 +4416,7 @@ public class Nd4j {
|
||||||
* @param shape desired shape of new array. Must match the resulting shape of combining the list.
|
* @param shape desired shape of new array. Must match the resulting shape of combining the list.
|
||||||
* @return the instance
|
* @return the instance
|
||||||
*/
|
*/
|
||||||
public static INDArray create(List<INDArray> list, int[] shape) {
|
public static INDArray create(List<INDArray> list, int... shape) {
|
||||||
checkShapeValues(shape);
|
checkShapeValues(shape);
|
||||||
|
|
||||||
INDArray ret = INSTANCE.create(list, shape);
|
INDArray ret = INSTANCE.create(list, shape);
|
||||||
|
@ -4356,7 +4426,7 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* @see #create(List, int[])
|
* @see #create(List, int[])
|
||||||
*/
|
*/
|
||||||
public static INDArray create(List<INDArray> list, long[] shape) {
|
public static INDArray create(List<INDArray> list, long... shape) {
|
||||||
checkShapeValues(shape);
|
checkShapeValues(shape);
|
||||||
|
|
||||||
INDArray ret = INSTANCE.create(list, shape);
|
INDArray ret = INSTANCE.create(list, shape);
|
||||||
|
@ -4641,7 +4711,7 @@ public class Nd4j {
|
||||||
* @param shape desired shape of new array.
|
* @param shape desired shape of new array.
|
||||||
* @return the created ndarray.
|
* @return the created ndarray.
|
||||||
*/
|
*/
|
||||||
public static INDArray create(DataBuffer data, int[] shape) {
|
public static INDArray create(DataBuffer data, int... shape) {
|
||||||
checkShapeValues(shape);
|
checkShapeValues(shape);
|
||||||
return INSTANCE.create(data, shape);
|
return INSTANCE.create(data, shape);
|
||||||
}
|
}
|
||||||
|
@ -4649,7 +4719,7 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* @see #create(DataBuffer, int[])
|
* @see #create(DataBuffer, int[])
|
||||||
*/
|
*/
|
||||||
public static INDArray create(DataBuffer data, long[] shape) {
|
public static INDArray create(DataBuffer data, long... shape) {
|
||||||
checkShapeValues(shape);
|
checkShapeValues(shape);
|
||||||
return INSTANCE.create(data, shape);
|
return INSTANCE.create(data, shape);
|
||||||
}
|
}
|
||||||
|
@ -5049,7 +5119,7 @@ public class Nd4j {
|
||||||
*
|
*
|
||||||
* @param shape
|
* @param shape
|
||||||
*/
|
*/
|
||||||
public static void checkShapeValues(long[] shape) {
|
public static void checkShapeValues(long... shape) {
|
||||||
for (long e: shape) {
|
for (long e: shape) {
|
||||||
if (e < 0)
|
if (e < 0)
|
||||||
throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape)
|
throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape)
|
||||||
|
@ -5062,7 +5132,7 @@ public class Nd4j {
|
||||||
*
|
*
|
||||||
* @param shape
|
* @param shape
|
||||||
*/
|
*/
|
||||||
public static void checkShapeValues(int[] shape) {
|
public static void checkShapeValues(int... shape) {
|
||||||
for (int e: shape) {
|
for (int e: shape) {
|
||||||
if (e < 1)
|
if (e < 1)
|
||||||
throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape)
|
throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape)
|
||||||
|
@ -5070,7 +5140,7 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected static void checkShapeValues(int length, int[] shape) {
|
protected static void checkShapeValues(int length, int... shape) {
|
||||||
checkShapeValues(shape);
|
checkShapeValues(shape);
|
||||||
|
|
||||||
if (ArrayUtil.prodLong(shape) > length)
|
if (ArrayUtil.prodLong(shape) > length)
|
||||||
|
@ -5078,7 +5148,7 @@ public class Nd4j {
|
||||||
+ " doesn't match data length: " + length);
|
+ " doesn't match data length: " + length);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected static void checkShapeValues(int length, long[] shape) {
|
protected static void checkShapeValues(int length, long... shape) {
|
||||||
checkShapeValues(shape);
|
checkShapeValues(shape);
|
||||||
|
|
||||||
if (ArrayUtil.prodLong(shape) > length)
|
if (ArrayUtil.prodLong(shape) > length)
|
||||||
|
@ -5171,7 +5241,7 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* @see #createUninitialized(long[])
|
* @see #createUninitialized(long[])
|
||||||
*/
|
*/
|
||||||
public static INDArray createUninitialized(int[] shape) {
|
public static INDArray createUninitialized(int... shape) {
|
||||||
if(shape.length == 0)
|
if(shape.length == 0)
|
||||||
return Nd4j.scalar(dataType(), 0.0);
|
return Nd4j.scalar(dataType(), 0.0);
|
||||||
checkShapeValues(shape);
|
checkShapeValues(shape);
|
||||||
|
@ -5186,7 +5256,7 @@ public class Nd4j {
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return the instance
|
* @return the instance
|
||||||
*/
|
*/
|
||||||
public static INDArray createUninitialized(long[] shape) {
|
public static INDArray createUninitialized(long... shape) {
|
||||||
checkShapeValues(shape);
|
checkShapeValues(shape);
|
||||||
//ensure shapes that wind up being scalar end up with the write shape
|
//ensure shapes that wind up being scalar end up with the write shape
|
||||||
return createUninitialized(shape, Nd4j.order());
|
return createUninitialized(shape, Nd4j.order());
|
||||||
|
@ -5199,7 +5269,7 @@ public class Nd4j {
|
||||||
* @param shape
|
* @param shape
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray createUninitializedDetached(int[] shape) {
|
public static INDArray createUninitializedDetached(int... shape) {
|
||||||
return createUninitializedDetached(shape, Nd4j.order());
|
return createUninitializedDetached(shape, Nd4j.order());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5210,7 +5280,7 @@ public class Nd4j {
|
||||||
* @param shape
|
* @param shape
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray createUninitializedDetached(long[] shape) {
|
public static INDArray createUninitializedDetached(long... shape) {
|
||||||
return createUninitializedDetached(shape, Nd4j.order());
|
return createUninitializedDetached(shape, Nd4j.order());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5448,7 +5518,7 @@ public class Nd4j {
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return the created array.
|
* @return the created array.
|
||||||
*/
|
*/
|
||||||
public static INDArray zeros(DataType dataType, long... shape) {
|
public static INDArray zeros(DataType dataType, @NonNull long... shape) {
|
||||||
return INSTANCE.create(dataType, shape, 'c', Nd4j.getMemoryManager().getCurrentWorkspace());
|
return INSTANCE.create(dataType, shape, 'c', Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5621,7 +5691,7 @@ public class Nd4j {
|
||||||
* @param shape Shape fo the array
|
* @param shape Shape fo the array
|
||||||
* @return the created ndarray
|
* @return the created ndarray
|
||||||
*/
|
*/
|
||||||
public static INDArray ones(DataType dataType, long... shape) {
|
public static INDArray ones(DataType dataType, @NonNull long... shape) {
|
||||||
if(shape.length == 0)
|
if(shape.length == 0)
|
||||||
return Nd4j.scalar(dataType, 1.0);
|
return Nd4j.scalar(dataType, 1.0);
|
||||||
INDArray ret = INSTANCE.createUninitialized(dataType, shape, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace());
|
INDArray ret = INSTANCE.createUninitialized(dataType, shape, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||||
|
@ -5635,8 +5705,9 @@ public class Nd4j {
|
||||||
*
|
*
|
||||||
* @param arrs the first matrix to concat
|
* @param arrs the first matrix to concat
|
||||||
*/
|
*/
|
||||||
public static INDArray hstack(INDArray... arrs) {
|
public static INDArray hstack(@NonNull INDArray... arrs) {
|
||||||
return INSTANCE.hstack(arrs);
|
INDArray ret = INSTANCE.hstack(arrs);
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -5656,7 +5727,7 @@ public class Nd4j {
|
||||||
*
|
*
|
||||||
* @param arrs Arrays to vstack
|
* @param arrs Arrays to vstack
|
||||||
*/
|
*/
|
||||||
public static INDArray vstack(INDArray... arrs) {
|
public static INDArray vstack(@NonNull INDArray... arrs) {
|
||||||
Preconditions.checkState(arrs != null && arrs.length > 0, "No input specified to vstack (null or length 0)");
|
Preconditions.checkState(arrs != null && arrs.length > 0, "No input specified to vstack (null or length 0)");
|
||||||
if(arrs[0].rank() == 1){
|
if(arrs[0].rank() == 1){
|
||||||
//Edge case: vstack rank 1 arrays - gives rank 2... vstack([3],[3]) -> [2,3]
|
//Edge case: vstack rank 1 arrays - gives rank 2... vstack([3],[3]) -> [2,3]
|
||||||
|
@ -5758,7 +5829,7 @@ public class Nd4j {
|
||||||
* @param arrays
|
* @param arrays
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray accumulate(INDArray... arrays) {
|
public static INDArray accumulate(@NonNull INDArray... arrays) {
|
||||||
if (arrays == null|| arrays.length == 0)
|
if (arrays == null|| arrays.length == 0)
|
||||||
throw new ND4JIllegalStateException("Input for accumulation is null or empty");
|
throw new ND4JIllegalStateException("Input for accumulation is null or empty");
|
||||||
|
|
||||||
|
@ -5798,7 +5869,7 @@ public class Nd4j {
|
||||||
* @param indexes indexes from source array
|
* @param indexes indexes from source array
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray pullRows(INDArray source, int sourceDimension, int[] indexes) {
|
public static INDArray pullRows(INDArray source, int sourceDimension, @NonNull int... indexes) {
|
||||||
return pullRows(source, sourceDimension, indexes, Nd4j.order());
|
return pullRows(source, sourceDimension, indexes, Nd4j.order());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5847,7 +5918,7 @@ public class Nd4j {
|
||||||
* @param indexes indexes from source array
|
* @param indexes indexes from source array
|
||||||
* @return Destination array with specified tensors
|
* @return Destination array with specified tensors
|
||||||
*/
|
*/
|
||||||
public static INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, int[] indexes){
|
public static INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, @NonNull int... indexes){
|
||||||
if (sourceDimension >= source.rank())
|
if (sourceDimension >= source.rank())
|
||||||
throw new IllegalStateException("Source dimension can't be higher the rank of source tensor");
|
throw new IllegalStateException("Source dimension can't be higher the rank of source tensor");
|
||||||
|
|
||||||
|
@ -5880,7 +5951,7 @@ public class Nd4j {
|
||||||
* @return Output array
|
* @return Output array
|
||||||
* @see #concat(int, INDArray...)
|
* @see #concat(int, INDArray...)
|
||||||
*/
|
*/
|
||||||
public static INDArray stack(int axis, INDArray... values){
|
public static INDArray stack(int axis, @NonNull INDArray... values){
|
||||||
Preconditions.checkArgument(values != null && values.length > 0, "No inputs: %s", values);
|
Preconditions.checkArgument(values != null && values.length > 0, "No inputs: %s", values);
|
||||||
Preconditions.checkState(axis >= -(values[0].rank()+1) && axis < values[0].rank()+1, "Invalid axis: must be between " +
|
Preconditions.checkState(axis >= -(values[0].rank()+1) && axis < values[0].rank()+1, "Invalid axis: must be between " +
|
||||||
"%s (inclusive) and %s (exclusive) for rank %s input, got %s", -(values[0].rank()+1), values[0].rank()+1,
|
"%s (inclusive) and %s (exclusive) for rank %s input, got %s", -(values[0].rank()+1), values[0].rank()+1,
|
||||||
|
@ -5902,7 +5973,7 @@ public class Nd4j {
|
||||||
* the ndarray shapes save the dimension shape specified
|
* the ndarray shapes save the dimension shape specified
|
||||||
* which is then the sum of the sizes along that dimension
|
* which is then the sum of the sizes along that dimension
|
||||||
*/
|
*/
|
||||||
public static INDArray concat(int dimension, INDArray... toConcat) {
|
public static INDArray concat(int dimension, @NonNull INDArray... toConcat) {
|
||||||
if(dimension < 0) {
|
if(dimension < 0) {
|
||||||
dimension += toConcat[0].rank();
|
dimension += toConcat[0].rank();
|
||||||
}
|
}
|
||||||
|
@ -5919,8 +5990,9 @@ public class Nd4j {
|
||||||
* @param toConcat
|
* @param toConcat
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray specialConcat(int dimension, INDArray... toConcat) {
|
public static INDArray specialConcat(int dimension, @NonNull INDArray... toConcat) {
|
||||||
return INSTANCE.specialConcat(dimension, toConcat);
|
INDArray ret = INSTANCE.specialConcat(dimension, toConcat);
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -5950,7 +6022,7 @@ public class Nd4j {
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return an ndarray with ones filled in
|
* @return an ndarray with ones filled in
|
||||||
*/
|
*/
|
||||||
public static INDArray zeros(int... shape) {
|
public static INDArray zeros(@NonNull int... shape) {
|
||||||
return Nd4j.create(shape);
|
return Nd4j.create(shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5961,7 +6033,7 @@ public class Nd4j {
|
||||||
* @param shape the shape of the array
|
* @param shape the shape of the array
|
||||||
* @return an ndarray with ones filled in
|
* @return an ndarray with ones filled in
|
||||||
*/
|
*/
|
||||||
public static INDArray zeros(long... shape) {
|
public static INDArray zeros(@NonNull long... shape) {
|
||||||
return Nd4j.create(shape);
|
return Nd4j.create(shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6085,7 +6157,7 @@ public class Nd4j {
|
||||||
* @return the strides for the given shape
|
* @return the strides for the given shape
|
||||||
* and order specified by NDArrays.order()
|
* and order specified by NDArrays.order()
|
||||||
*/
|
*/
|
||||||
public static int[] getStrides(int[] shape) {
|
public static int[] getStrides(@NonNull int... shape) {
|
||||||
return getStrides(shape, Nd4j.order());
|
return getStrides(shape, Nd4j.order());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6097,7 +6169,7 @@ public class Nd4j {
|
||||||
* @return the strides for the given shape
|
* @return the strides for the given shape
|
||||||
* and order specified by NDArrays.order()
|
* and order specified by NDArrays.order()
|
||||||
*/
|
*/
|
||||||
public static long[] getStrides(long[] shape) {
|
public static long[] getStrides(@NonNull long... shape) {
|
||||||
return getStrides(shape, Nd4j.order());
|
return getStrides(shape, Nd4j.order());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6108,7 +6180,7 @@ public class Nd4j {
|
||||||
* @param repeat the shape to repeat
|
* @param repeat the shape to repeat
|
||||||
* @return the tiled ndarray
|
* @return the tiled ndarray
|
||||||
*/
|
*/
|
||||||
public static INDArray tile(INDArray tile, int... repeat) {
|
public static INDArray tile(INDArray tile, @NonNull int... repeat) {
|
||||||
int d = repeat.length;
|
int d = repeat.length;
|
||||||
long[] shape = ArrayUtil.copy(tile.shape());
|
long[] shape = ArrayUtil.copy(tile.shape());
|
||||||
long n = Math.max(tile.length(), 1);
|
long n = Math.max(tile.length(), 1);
|
||||||
|
@ -6157,11 +6229,11 @@ public class Nd4j {
|
||||||
* @return the strides for the given shape
|
* @return the strides for the given shape
|
||||||
* and order specified by NDArrays.order()
|
* and order specified by NDArrays.order()
|
||||||
*/
|
*/
|
||||||
public static int[] getComplexStrides(int[] shape) {
|
public static int[] getComplexStrides(@NonNull int... shape) {
|
||||||
return getComplexStrides(shape, Nd4j.order());
|
return getComplexStrides(shape, Nd4j.order());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static long[] getComplexStrides(long[] shape) {
|
public static long[] getComplexStrides(@NonNull long... shape) {
|
||||||
return getComplexStrides(shape, Nd4j.order());
|
return getComplexStrides(shape, Nd4j.order());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6501,7 +6573,7 @@ public class Nd4j {
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray pile(INDArray... arrays) {
|
public static INDArray pile(@NonNull INDArray... arrays) {
|
||||||
// if we have vectors as input, it's just vstack use case
|
// if we have vectors as input, it's just vstack use case
|
||||||
|
|
||||||
long[] shape = arrays[0].shape();
|
long[] shape = arrays[0].shape();
|
||||||
|
@ -6536,7 +6608,7 @@ public class Nd4j {
|
||||||
* @param dimensions
|
* @param dimensions
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray[] tear(INDArray tensor, int... dimensions) {
|
public static INDArray[] tear(INDArray tensor, @NonNull int... dimensions) {
|
||||||
if (dimensions.length >= tensor.rank())
|
if (dimensions.length >= tensor.rank())
|
||||||
throw new ND4JIllegalStateException("Target dimensions number should be less tensor rank");
|
throw new ND4JIllegalStateException("Target dimensions number should be less tensor rank");
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,8 @@ import java.util.Locale;
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
* @author Susan Eraly
|
* @author Susan Eraly
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
public class NDArrayStrings {
|
public class NDArrayStrings {
|
||||||
|
|
||||||
public static final String EMPTY_ARRAY_STR = "[]";
|
public static final String EMPTY_ARRAY_STR = "[]";
|
||||||
|
@ -55,7 +57,7 @@ public class NDArrayStrings {
|
||||||
@Setter @Getter
|
@Setter @Getter
|
||||||
private static long maxPrintElements = DEFAULT_MAX_PRINT_ELEMENTS;
|
private static long maxPrintElements = DEFAULT_MAX_PRINT_ELEMENTS;
|
||||||
|
|
||||||
|
private long localMaxPrintElements = maxPrintElements;
|
||||||
private String colSep = ",";
|
private String colSep = ",";
|
||||||
private String newLineSep = ",";
|
private String newLineSep = ",";
|
||||||
private int padding = 7;
|
private int padding = 7;
|
||||||
|
@ -82,6 +84,38 @@ public class NDArrayStrings {
|
||||||
this(",", precision);
|
this(",", precision);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public NDArrayStrings(long maxElements, int precision) {
|
||||||
|
this(",", precision);
|
||||||
|
this.localMaxPrintElements = maxElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
public NDArrayStrings(long maxElements) {
|
||||||
|
this();
|
||||||
|
this.localMaxPrintElements = maxElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
public NDArrayStrings(long maxElements, boolean forceSummarize, int precision) {
|
||||||
|
this(",", precision);
|
||||||
|
if(forceSummarize)
|
||||||
|
localMaxPrintElements = 0;
|
||||||
|
else
|
||||||
|
localMaxPrintElements = maxElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
public NDArrayStrings(boolean forceSummarize, int precision) {
|
||||||
|
this(",", precision);
|
||||||
|
if(forceSummarize)
|
||||||
|
localMaxPrintElements = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
public NDArrayStrings(boolean forceSummarize) {
|
||||||
|
this(",", 4);
|
||||||
|
if(forceSummarize)
|
||||||
|
localMaxPrintElements = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specify a delimiter for elements in columns for 2d arrays (or in the rank-1th dimension in higher order arrays)
|
* Specify a delimiter for elements in columns for 2d arrays (or in the rank-1th dimension in higher order arrays)
|
||||||
|
@ -93,13 +127,17 @@ public class NDArrayStrings {
|
||||||
public NDArrayStrings(String colSep, int precision) {
|
public NDArrayStrings(String colSep, int precision) {
|
||||||
this.colSep = colSep;
|
this.colSep = colSep;
|
||||||
if (!colSep.replaceAll("\\s", "").equals(",")) this.newLineSep = "";
|
if (!colSep.replaceAll("\\s", "").equals(",")) this.newLineSep = "";
|
||||||
this.precision = precision;
|
StringBuilder decFormatNum = new StringBuilder("0.");
|
||||||
String decFormatNum = "0.";
|
|
||||||
while (precision > 0) {
|
int prec = Math.abs(precision);
|
||||||
decFormatNum += "0";
|
this.precision = prec;
|
||||||
precision -= 1;
|
boolean useHash = precision < 0;
|
||||||
|
|
||||||
|
while (prec > 0) {
|
||||||
|
decFormatNum.append(useHash ? "#" : "0");
|
||||||
|
prec -= 1;
|
||||||
}
|
}
|
||||||
this.decimalFormat = localeIndifferentDecimalFormat(decFormatNum);
|
this.decimalFormat = localeIndifferentDecimalFormat(decFormatNum.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -147,7 +185,7 @@ public class NDArrayStrings {
|
||||||
if (this.scientificFormat.length() + 2 > this.padding) this.padding = this.scientificFormat.length() + 2;
|
if (this.scientificFormat.length() + 2 > this.padding) this.padding = this.scientificFormat.length() + 2;
|
||||||
this.maxToPrintWithoutSwitching = Math.pow(10,this.precision);
|
this.maxToPrintWithoutSwitching = Math.pow(10,this.precision);
|
||||||
this.minToPrintWithoutSwitching = 1.0/(this.maxToPrintWithoutSwitching);
|
this.minToPrintWithoutSwitching = 1.0/(this.maxToPrintWithoutSwitching);
|
||||||
return format(arr, 0, summarize && arr.length() > maxPrintElements);
|
return format(arr, 0, summarize && arr.length() > localMaxPrintElements);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String format(INDArray arr, int offset, boolean summarize) {
|
private String format(INDArray arr, int offset, boolean summarize) {
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
import java.io.ObjectOutputStream;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
|
@Slf4j
|
||||||
|
public class ToStringTest extends BaseNd4jTest {
|
||||||
|
public ToStringTest(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testToString() throws Exception {
|
||||||
|
assertEquals("[ 1, 2, 3]",
|
||||||
|
Nd4j.createFromArray(1, 2, 3).toString());
|
||||||
|
|
||||||
|
assertEquals("[ 1, 2, 3 ... 6 7, 8]",
|
||||||
|
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(1000, false, 2));
|
||||||
|
|
||||||
|
assertEquals("[ 1.132, 2.644, 3.234]",
|
||||||
|
Nd4j.createFromArray(1.132414, 2.64356456, 3.234234).toString(1000, false, 3));
|
||||||
|
|
||||||
|
assertEquals("[ 1.132414, 2.64356456, 3.25345234]",
|
||||||
|
Nd4j.createFromArray(1.132414, 2.64356456, 3.25345234).toStringFull());
|
||||||
|
|
||||||
|
assertEquals("[ 1, 2, 3 ... 6 7, 8]",
|
||||||
|
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(100, true, 1));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
}
|
|
@ -124,6 +124,36 @@ public enum DataType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return the max number of significant decimal digits
|
||||||
|
*/
|
||||||
|
public int precision(){
|
||||||
|
switch (this){
|
||||||
|
case DOUBLE:
|
||||||
|
return 17;
|
||||||
|
case FLOAT:
|
||||||
|
return 9;
|
||||||
|
case HALF:
|
||||||
|
return 5;
|
||||||
|
case BFLOAT16:
|
||||||
|
return 4;
|
||||||
|
case LONG:
|
||||||
|
case INT:
|
||||||
|
case SHORT:
|
||||||
|
case BYTE:
|
||||||
|
case UBYTE:
|
||||||
|
case BOOL:
|
||||||
|
case UTF8:
|
||||||
|
case COMPRESSED:
|
||||||
|
case UINT16:
|
||||||
|
case UINT32:
|
||||||
|
case UINT64:
|
||||||
|
case UNKNOWN:
|
||||||
|
default:
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return For fixed-width types, this returns the number of bytes per array element
|
* @return For fixed-width types, this returns the number of bytes per array element
|
||||||
*/
|
*/
|
||||||
|
|
Loading…
Reference in New Issue