From 59cba587f4d99507c2019cef5eb6c477b9b0a8c4 Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Thu, 15 Aug 2019 13:50:52 +0900 Subject: [PATCH] Nd4j refactoring (#112) * refactoring Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * wip * fix: make test public. Signed-off-by: Robert Altena * make test public. Signed-off-by: Robert Altena * fixes read refactoring. Signed-off-by: Robert Altena --- .../java/org/nd4j/linalg/factory/Nd4j.java | 519 ++++++------------ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 8 + .../org/nd4j/linalg/api/rng/RngTests.java | 2 +- 3 files changed, 169 insertions(+), 360 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 429782c3e..58d7a113e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -16,7 +16,6 @@ package org.nd4j.linalg.factory; -import com.google.common.base.Function; import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; import lombok.NonNull; @@ -85,6 +84,7 @@ import org.nd4j.linalg.memory.deallocation.DeallocatorService; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.string.NDArrayStrings; import org.nd4j.linalg.util.ArrayUtil; +import org.nd4j.linalg.util.LongUtils; import org.nd4j.tools.PropertyParser; import org.nd4j.versioncheck.VersionCheck; @@ -1756,14 +1756,12 @@ public class Nd4j { index[j] = (double) j; } - /** + /* * Inject a comparator that sorts indices relative to * the actual values in the data. * This allows us to retain the indices * and how they were rearranged. */ - - Arrays.sort(index, new Comparator() { @Override public int compare(Double o1, Double o2) { @@ -2155,6 +2153,7 @@ public class Nd4j { * Defaults to scientific notation with 18 digits after the decimal * Use {@link #writeTxt(INDArray, String)} */ + @SuppressWarnings("unused") //backward compatibility. public static void writeTxt(INDArray write, String filePath, String split, int precision) { writeTxt(write,filePath); } @@ -2169,6 +2168,7 @@ public class Nd4j { * Defaults to scientific notation with 18 digits after the decimal * Use {@link #writeTxt(INDArray, String)} */ + @SuppressWarnings("unused") //backward compatibility. public static void writeTxt(INDArray write, String filePath, int precision) { writeTxt(write, filePath); } @@ -2182,6 +2182,7 @@ public class Nd4j { * @deprecated custom col and higher dimension separators are no longer supported; uses "," * Use {@link #writeTxt(INDArray, String)} */ + @SuppressWarnings("unused") public static void writeTxt(INDArray write, String filePath, String split) { writeTxt(write,filePath); } @@ -2206,15 +2207,13 @@ public class Nd4j { write = write.dup(); String format = "0.000000000000000000E0"; - return new StringBuilder() - .append("{\n") - .append("\"filefrom\": \"dl4j\",\n") - .append( "\"ordering\": \"").append(write.ordering()).append("\",\n") - .append("\"shape\":\t").append( java.util.Arrays.toString(write.shape())).append(",\n") - .append("\"data\":\n") - .append(new NDArrayStrings(",", format).format(write, false)) - .append("\n}\n") - .toString(); + return "{\n" + + "\"filefrom\": \"dl4j\",\n" + + "\"ordering\": \"" + write.ordering() + "\",\n" + + "\"shape\":\t" + Arrays.toString(write.shape()) + ",\n" + + "\"data\":\n" + + new NDArrayStrings(",", format).format(write, false) + + "\n}\n"; } @@ -2242,8 +2241,7 @@ public class Nd4j { ByteArrayOutputStream bos = new ByteArrayOutputStream((int) (arr.length() * arr.data().getElementSize())); DataOutputStream dos = new DataOutputStream(bos); write(arr, dos); - byte[] ret = bos.toByteArray(); - return ret; + return bos.toByteArray(); } /** @@ -2251,10 +2249,9 @@ public class Nd4j { * @param arr the array to read from * @return the deserialized ndarray */ - public static INDArray fromByteArray(@NonNull byte[] arr) throws IOException { + public static INDArray fromByteArray(@NonNull byte[] arr) { ByteArrayInputStream bis = new ByteArrayInputStream(arr); - INDArray ret = read(bis); - return ret; + return read(bis); } /** @@ -2277,6 +2274,7 @@ public class Nd4j { * @param charset the charset * @return the deserialized array. */ + @SuppressWarnings("WeakerAccess") //really should add testing for the method. public static INDArray readNumpy(@NonNull DataType dataType, @NonNull InputStream filePath, @NonNull String split, @NonNull Charset charset) throws IOException { BufferedReader reader = new BufferedReader(new InputStreamReader(filePath, charset)); String line; @@ -2369,7 +2367,7 @@ public class Nd4j { * * See {@link #read(DataInputStream)} */ - public static INDArray read(InputStream reader) throws IOException { + public static INDArray read(InputStream reader) { return read(new DataInputStream(reader)); } @@ -2379,6 +2377,7 @@ public class Nd4j { * @param ndarray the input stream ndarray * @return NDArray */ + @SuppressWarnings("WeakerAccess") public static INDArray readTxtString(InputStream ndarray) { String sep = ","; /* @@ -2447,6 +2446,7 @@ public class Nd4j { String[] entries = line.replace("\\],", "").replaceAll("]", "").replaceAll("\\[", "").split(sep); if (rank == 0) { try { + //noinspection ConstantConditions newArr.addi((format.parse(entries[0])).doubleValue()); } catch (ParseException e) { e.printStackTrace(); @@ -2558,17 +2558,17 @@ public class Nd4j { public static INDArray read(DataInputStream dis) { val headerShape = BaseDataBuffer.readHeader(dis); + //noinspection UnnecessaryUnboxing var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle().longValue()}, headerShape.getRight()); shapeInformation.read(dis, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getThird()); - val length = Shape.length(shapeInformation); - DataType type = null; + DataType type; DataBuffer data = null; val headerData = BaseDataBuffer.readHeader(dis); try { // current version contains dtype in extras data = CompressedDataBuffer.readUnknown(dis, headerData.getFirst(), headerData.getMiddle(), headerData.getRight()); - type = ArrayOptionsHelper.dataType(shapeInformation.asLong()); + ArrayOptionsHelper.dataType(shapeInformation.asLong()); } catch (ND4JUnknownDataTypeException e) { // manually setting data type type = headerData.getRight(); @@ -2731,16 +2731,7 @@ public class Nd4j { public static INDArray choice(INDArray source, INDArray probs, INDArray target) { return choice(source, probs, target, Nd4j.getRandom()); } - - /** - * - * - * @param source - * @param probs - * @param numSamples - * @return - */ - + // @see tag works well here. /** * This method returns new INDArray instance, sampled from Source array with probabilities given in Probs. @@ -3749,14 +3740,13 @@ public class Nd4j { * This method creates new 0D INDArray, aka scalar. * * PLEASE NOTE: Temporary method, added to ensure backward compatibility - * @param scalar - * @return - * @deprecated Use Nd4j.scalar methods, such as {@link #scalar(double)} or {@link #scalar(DataType, Number)} + * @param scalar data for INDArray. + * @return new INDArray + * * @deprecated Use Nd4j.scalar methods, such as {@link #scalar(double)} or {@link #scalar(DataType, Number)} */ @Deprecated public static INDArray trueScalar(Number scalar) { - val ret = INSTANCE.trueScalar(scalar); - return ret; + return INSTANCE.trueScalar(scalar); } /** @@ -3764,8 +3754,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(boolean[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3773,8 +3762,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(long[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3782,8 +3770,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(int[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3791,8 +3778,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(float[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3800,8 +3786,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(double[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3820,7 +3805,7 @@ public class Nd4j { */ public static INDArray empty(DataType type) { if(EMPTY_ARRAYS[type.ordinal()] == null){ - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()){ + try(MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()){ val ret = INSTANCE.empty(type); EMPTY_ARRAYS[type.ordinal()] = ret; } @@ -3844,11 +3829,8 @@ public class Nd4j { if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array doesn't match data length"); } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape); - return ret; + checkShapeValues(data.length, LongUtils.toLongs(shape)); + return INSTANCE.create(data, shape); } /** @@ -3858,16 +3840,8 @@ public class Nd4j { if (shape.length == 0 && data.length == 1) { return scalar(data[0]); } - - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array doesn't match data length"); - } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, Nd4j.getStrides(shape, Nd4j.order()), DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + commonCheckCreate(data.length, shape); + return INSTANCE.create(data, shape, Nd4j.getStrides(shape, Nd4j.order()), DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -3877,16 +3851,8 @@ public class Nd4j { if (shape.length == 0 && data.length == 1) { return scalar(data[0]); } - - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array doesn't match data length"); - } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, Nd4j.getStrides(shape, Nd4j.order()), DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + commonCheckCreate(data.length, shape); + return INSTANCE.create(data, shape, Nd4j.getStrides(shape, Nd4j.order()), DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -3897,16 +3863,9 @@ public class Nd4j { * @return the created ndarray */ public static INDArray create(double[] data, int... shape) { - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); - + commonCheckCreate(data.length, LongUtils.toLongs(shape)); val lshape = ArrayUtil.toLongArray(shape); - INDArray ret = INSTANCE.create(data, lshape, Nd4j.getStrides(lshape, Nd4j.order()), DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + return INSTANCE.create(data, lshape, Nd4j.getStrides(lshape, Nd4j.order()), DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -3920,32 +3879,26 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(double[] data, int[] shape, long offset, char ordering) { - if (shape.length == 1) { - if (shape[0] != data.length) + commonCheckCreate(data.length, LongUtils.toLongs(shape)); + return INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering); + } + + private static void commonCheckCreate( int dataLength, long[] shape){ + if (shape.length== 1) { + if (shape[0] != dataLength) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); + + " doesn't match data length: " + dataLength); } - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering); - return ret; + checkShapeValues(dataLength, shape); } /** * See {@link #create(double[], int[], long, char )} */ public static INDArray create(double[] data, long[] shape, long offset, char ordering) { - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering); - return ret; + commonCheckCreate(data.length, shape); + return INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering); } /** @@ -3958,16 +3911,8 @@ public class Nd4j { * @return the instance */ public static INDArray create(float[] data, int[] shape, int[] stride, long offset) { - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, stride, offset); - return ret; + commonCheckCreate(data.length, LongUtils.toLongs(shape)); + return INSTANCE.create(data, shape, stride, offset); } /** @@ -3979,9 +3924,7 @@ public class Nd4j { */ public static INDArray create(List list, int... shape) { checkShapeValues(shape); - - INDArray ret = INSTANCE.create(list, shape); - return ret; + return INSTANCE.create(list, shape); } /** @@ -3989,9 +3932,7 @@ public class Nd4j { */ public static INDArray create(List list, long... shape) { checkShapeValues(shape); - - INDArray ret = INSTANCE.create(list, shape); - return ret; + return INSTANCE.create(list, shape); } /** @@ -4027,10 +3968,7 @@ public class Nd4j { */ public static INDArray create(int[] shape, int[] stride, long offset) { checkShapeValues(shape); - - INDArray ret = INSTANCE.create(shape, stride, offset); - return ret; - + return INSTANCE.create(shape, stride, offset); } /** @@ -4260,34 +4198,21 @@ public class Nd4j { /** * Create an array withgiven shape and ordering based on a java double array. * @param data java array used for initialisation. Must have at least the number of elements required. - * @@param shape desired shape of new array. + * @param shape desired shape of new array. * @param ordering Fortran 'f' or C/C++ 'c' ordering. * @return the created ndarray. */ public static INDArray create(double[] data, int[] shape, char ordering) { - //TODO: duplicate code and issue #8013 - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); + commonCheckCreate(data.length, LongUtils.toLongs(shape)); val lshape = ArrayUtil.toLongArray(shape); - return INSTANCE.create(data, lshape, Nd4j.getStrides(lshape, ordering), ordering, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); + return INSTANCE.create(data, lshape, Nd4j.getStrides(lshape, ordering), ordering, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** * See {@link #create(double[], int[], char)} */ public static INDArray create(float[] data, int[] shape, char ordering) { - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); + commonCheckCreate(data.length, LongUtils.toLongs(shape)); return INSTANCE.create(data, shape, ordering); } @@ -4307,22 +4232,6 @@ public class Nd4j { return INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), ordering, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } - /** - * Creates an ndarray with the specified shape - * TODO: unused method. (only used by the zeros method in this class) - * - * @param rows the rows of the ndarray - * @param columns the columns of the ndarray - * @param stride the stride for the ndarray - * @param offset the offset of the ndarray - * @return the instance - */ - public static INDArray create(int rows, int columns, int[] stride, long offset, char ordering) { - int[] shape = new int[]{rows, columns}; - checkShapeValues(shape); - return INSTANCE.create(shape, stride, offset, ordering); - } - /** * Creates an ndarray with the specified shape * @@ -4469,11 +4378,9 @@ public class Nd4j { return INSTANCE.create(dataType, shape, ordering, Nd4j.getMemoryManager().getCurrentWorkspace()); } - - // TODO: Leaving these until #8028 is fixed. /** - * - * @param shape + * Throws exception on negative shape values. + * @param shape to check */ public static void checkShapeValues(long... shape) { for (long e: shape) { @@ -4483,30 +4390,13 @@ public class Nd4j { } } - // TODO: Leaving these until #8028 is fixed. - /** - * - * @param shape - */ - public static void checkShapeValues(int... shape) { - for (int e: shape) { - if (e < 0) - throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape) - + " contains dimension size values < 0 (all dimensions must be 0 or more)"); - } + // made private as it is only used for internal checks. + private static void checkShapeValues(int... shape) { + checkShapeValues(LongUtils.toLongs(shape)); } - protected static void checkShapeValues(int length, int... shape) { + private static void checkShapeValues(int length, long... shape) { checkShapeValues(shape); - - if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0)) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided"); - } - - protected static void checkShapeValues(int length, long... shape) { - checkShapeValues(shape); - if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0)) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided"); @@ -4604,8 +4494,8 @@ public class Nd4j { * * PLEASE NOTE: Do not use this method unless you're 100% sure why you use it. * - * @param length - * @return + * @param length length of array to create + * @return the created INDArray */ public static INDArray createUninitialized(long length) { long[] shape = new long[] {length}; @@ -4619,6 +4509,7 @@ public class Nd4j { * @param shape the shape of the array. * @return the created detached array. */ + @SuppressWarnings("WeakerAccess") // For now. If part of public API it will need testing. public static INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){ return INSTANCE.createUninitializedDetached(dataType, ordering, shape); } @@ -4738,6 +4629,7 @@ public class Nd4j { * @param type data type * @return the created ndarray */ + @SuppressWarnings("Duplicates") public static INDArray valueArrayOf(long[] shape, double value, DataType type) { if (shape.length == 0) return scalar(type, value); @@ -4752,6 +4644,7 @@ public class Nd4j { /** * See {@link #valueArrayOf(long[], double, DataType)} */ + @SuppressWarnings("Duplicates") public static INDArray valueArrayOf(long[] shape, long value, DataType type) { if (shape.length == 0) return scalar(type, value); @@ -4774,8 +4667,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray valueArrayOf(long num, double value) { - INDArray ret = INSTANCE.valueArrayOf(new long[] {num}, value); - return ret; + return INSTANCE.valueArrayOf(new long[] {num}, value); } /** @@ -4789,8 +4681,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray valueArrayOf(long rows, long columns, double value) { - INDArray ret = INSTANCE.valueArrayOf(rows, columns, value); - return ret; + return INSTANCE.valueArrayOf(rows, columns, value); } /** @@ -4801,8 +4692,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray ones(int rows, int columns) { - INDArray ret = INSTANCE.ones(rows, columns); - return ret; + return INSTANCE.ones(rows, columns); } /** @@ -4860,8 +4750,7 @@ public class Nd4j { * @param arrs the first matrix to concat */ public static INDArray hstack(@NonNull INDArray... arrs) { - INDArray ret = INSTANCE.hstack(arrs); - return ret; + return INSTANCE.hstack(arrs); } /** @@ -4883,6 +4772,7 @@ public class Nd4j { */ public static INDArray vstack(@NonNull INDArray... arrs) { Preconditions.checkState(arrs != null && arrs.length > 0, "No input specified to vstack (null or length 0)"); + //noinspection ConstantConditions if(arrs[0].rank() == 1){ //Edge case: vstack rank 1 arrays - gives rank 2... vstack([3],[3]) -> [2,3] return pile(arrs); @@ -4901,24 +4791,12 @@ public class Nd4j { return vstack(arrays); } - // TODO: unused method /** * This method averages input arrays, and returns averaged array. * On top of that, averaged array is propagated to all input arrays * - * @param arrays - * @return - */ - public static INDArray averageAndPropagate(INDArray target, INDArray[] arrays) { - return INSTANCE.average(target, arrays); - } - - /** - * This method averages input arrays, and returns averaged array. - * On top of that, averaged array is propagated to all input arrays - * - * @param arrays - * @return + * @param arrays arrays to average + * @return averaged arrays */ public static INDArray averageAndPropagate(INDArray[] arrays) { return INSTANCE.average(arrays); @@ -4929,8 +4807,8 @@ public class Nd4j { * This method averages input arrays, and returns averaged array. * On top of that, averaged array is propagated to all input arrays * - * @param arrays - * @return + * @param arrays arrays to average + * @return averaged arrays */ public static INDArray averageAndPropagate(Collection arrays) { return INSTANCE.average(arrays); @@ -4940,20 +4818,19 @@ public class Nd4j { * This method averages input arrays, and returns averaged array. * On top of that, averaged array is propagated to all input arrays * - * @param arrays - * @return + * @param arrays arrays to average + * @return averaged arrays */ public static INDArray averageAndPropagate(INDArray target, Collection arrays) { return INSTANCE.average(target, arrays); } - - /** * Reshapes an ndarray to remove leading 1s * @param toStrip the ndarray to newShapeNoCopy * @return the reshaped ndarray */ + @SuppressWarnings("WeakerAccess") // Needs tests if part of public API. public static INDArray stripOnes(INDArray toStrip) { if (toStrip.isVector()) return toStrip; @@ -4966,8 +4843,8 @@ public class Nd4j { /** * This method sums given arrays and stores them to a new array * - * @param arrays - * @return + * @param arrays array to accumulate + * @return accumulated array. */ public static INDArray accumulate(@NonNull INDArray... arrays) { if (arrays == null|| arrays.length == 0) @@ -4979,9 +4856,9 @@ public class Nd4j { /** * This method sums given arrays and stores them to a given target array * - * @param target - * @param arrays - * @return + * @param target result array + * @param arrays arrays to sum + * @return result array */ public static INDArray accumulate(INDArray target, Collection arrays) { return accumulate(target, arrays.toArray(new INDArray[0])); @@ -4990,14 +4867,13 @@ public class Nd4j { /** * This method sums given arrays and stores them to a given target array * - * @param target - * @param arrays - * @return + * @param target result array + * @param arrays arrays to sum + * @return result array */ public static INDArray accumulate(INDArray target, INDArray[] arrays) { if (arrays == null|| arrays.length == 0) return target; - return factory().accumulate(target, arrays); } @@ -5007,7 +4883,7 @@ public class Nd4j { * @param source source tensor * @param sourceDimension dimension of source tensor * @param indexes indexes from source array - * @return + * @return result array */ public static INDArray pullRows(INDArray source, int sourceDimension, @NonNull int... indexes) { return pullRows(source, sourceDimension, indexes, Nd4j.order()); @@ -5022,8 +4898,9 @@ public class Nd4j { * @param source source tensor * @param sourceDimension dimension of source tensor * @param indexes indexes from source array - * @return + * @return concatenated array */ + @SuppressWarnings("Duplicates") public static INDArray pullRows(INDArray source, int sourceDimension, int[] indexes, char order) { if (sourceDimension >= source.rank()) throw new IllegalStateException("Source dimension can't be higher the rank of source tensor"); @@ -5042,9 +4919,7 @@ public class Nd4j { } Preconditions.checkArgument(source.rank() > 1, "pullRows() can't operate on 0D/1D arrays"); - - INDArray ret = INSTANCE.pullRows(source, sourceDimension, indexes, order); - return ret; + return INSTANCE.pullRows(source, sourceDimension, indexes, order); } /** @@ -5058,6 +4933,7 @@ public class Nd4j { * @param indexes indexes from source array * @return Destination array with specified tensors */ + @SuppressWarnings("Duplicates") public static INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, @NonNull int... indexes){ if (sourceDimension >= source.rank()) throw new IllegalStateException("Source dimension can't be higher the rank of source tensor"); @@ -5074,8 +4950,7 @@ public class Nd4j { Preconditions.checkArgument(source.rank() > 1, "pullRows() can't operate on 0D/1D arrays"); - INDArray ret = INSTANCE.pullRows(source, destination, sourceDimension, indexes); - return ret; + return INSTANCE.pullRows(source, destination, sourceDimension, indexes); } /** @@ -5091,8 +4966,9 @@ public class Nd4j { * @return Output array * @see #concat(int, INDArray...) */ + @SuppressWarnings("ConstantConditions") public static INDArray stack(int axis, @NonNull INDArray... values){ - Preconditions.checkArgument(values != null && values.length > 0, "No inputs: %s", values); + Preconditions.checkArgument(values != null && values.length > 0, "No inputs: %s", (Object[]) values); Preconditions.checkState(axis >= -(values[0].rank()+1) && axis < values[0].rank()+1, "Invalid axis: must be between " + "%s (inclusive) and %s (exclusive) for rank %s input, got %s", -(values[0].rank()+1), values[0].rank()+1, values[0].rank(), axis); @@ -5126,13 +5002,12 @@ public class Nd4j { * * PLEASE NOTE: This method is special for GPU backend, it works on HOST side only. * - * @param dimension - * @param toConcat - * @return + * @param dimension dimension + * @param toConcat arrayts to concatenate + * @return concatenated arrays. */ public static INDArray specialConcat(int dimension, @NonNull INDArray... toConcat) { - INDArray ret = INSTANCE.specialConcat(dimension, toConcat); - return ret; + return INSTANCE.specialConcat(dimension, toConcat); } /** @@ -5143,9 +5018,7 @@ public class Nd4j { */ public static INDArray zeros(int[] shape, char order) { checkShapeValues(shape); - - INDArray ret = INSTANCE.create(shape, order); - return ret; + return INSTANCE.create(shape, order); } /** @@ -5184,9 +5057,7 @@ public class Nd4j { * @return an ndarray with ones filled in */ public static INDArray ones(@NonNull int... shape) { - if(shape.length == 0) - return Nd4j.scalar(dataType(), 1.0); - return INSTANCE.ones(shape); + return (shape.length == 0) ? Nd4j.scalar(dataType(), 1.0) : INSTANCE.ones(shape); } @@ -5216,6 +5087,7 @@ public class Nd4j { * @param value the value to initialize the scalar with * @return the created ndarray */ + @SuppressWarnings("deprecation") public static INDArray scalar(DataType dataType, Number value) { return INSTANCE.trueScalar(dataType, value); } @@ -5247,7 +5119,7 @@ public class Nd4j { * @return the scalar nd array */ public static INDArray scalar(boolean value) { - return INSTANCE.trueScalar(DataType.BOOL, value ? 1 : 0); + return scalar(DataType.BOOL, value ? 1 : 0); } /** @@ -5324,45 +5196,10 @@ public class Nd4j { return Nd4j.exec(new Tile(new INDArray[]{tile}, new INDArray[]{}, repeat))[0]; } - /** - * Get the strides for the given order and shape - * - * @param shape the shape of the array - * @param order the order to getScalar the strides for - * @return the strides for the given shape and order - */ - public static int[] getComplexStrides(int[] shape, char order) { - if (order == NDArrayFactory.FORTRAN) - return ArrayUtil.calcStridesFortran(shape, 2); - return ArrayUtil.calcStrides(shape, 2); - } - - public static long[] getComplexStrides(long[] shape, char order) { - if (order == NDArrayFactory.FORTRAN) - return ArrayUtil.calcStridesFortran(shape, 2); - return ArrayUtil.calcStrides(shape, 2); - } - - /** - * Get the strides based on the shape - * and NDArrays.order() - * - * @param shape the shape of the array - * @return the strides for the given shape - * and order specified by NDArrays.order() - */ - public static int[] getComplexStrides(@NonNull int... shape) { - return getComplexStrides(shape, Nd4j.order()); - } - - public static long[] getComplexStrides(@NonNull long... shape) { - return getComplexStrides(shape, Nd4j.order()); - } - /** * Initializes nd4j */ - public synchronized void initContext() { + private synchronized void initContext() { try { defaultFloatingPointDataType = new AtomicReference<>(); defaultFloatingPointDataType.set(DataType.FLOAT); @@ -5377,6 +5214,7 @@ public class Nd4j { * Initialize with the specific backend * @param backend the backend to initialize with */ + @SuppressWarnings({"unchecked", "Duplicates"}) public void initWithBackend(Nd4jBackend backend) { VersionCheck.checkVersions(); @@ -5557,7 +5395,7 @@ public class Nd4j { /** * - * @return + * @return Shape info provider */ public static ShapeInfoProvider getShapeInfoProvider() { return shapeInfoProvider; @@ -5565,7 +5403,7 @@ public class Nd4j { /** * - * @return + * @return Sparse shape info provider */ public static SparseInfoProvider getSparseInfoProvider() { return sparseInfoProvider; @@ -5573,7 +5411,7 @@ public class Nd4j { /** * - * @return + * @return constant handler */ public static ConstantHandler getConstantHandler() { return constantHandler; @@ -5581,7 +5419,7 @@ public class Nd4j { /** * - * @return + * @return affinity manager */ public static AffinityManager getAffinityManager() { return affinityManager; @@ -5589,7 +5427,7 @@ public class Nd4j { /** * - * @return + * @return NDArrayFactory */ public static NDArrayFactory getNDArrayFactory() { return INSTANCE; @@ -5600,7 +5438,7 @@ public class Nd4j { * suitable for NDArray compression/decompression * at runtime * - * @return + * @return BasicNDArrayCompressor instance */ public static BasicNDArrayCompressor getCompressor() { return BasicNDArrayCompressor.getInstance(); @@ -5608,16 +5446,12 @@ public class Nd4j { /** * This method returns backend-specific MemoryManager implementation, for low-level memory management - * @return + * @return MemoryManager */ public static MemoryManager getMemoryManager() { return memoryManager; } - public static INDArray typeConversion(INDArray array, DataTypeEx targetType) { - return null; - } - /** * This method returns sizeOf(currentDataType), in bytes * @@ -5633,7 +5467,7 @@ public class Nd4j { * This method returns size of element for specified dataType, in bytes * * @param dtype number of bytes per element - * @return + * @return element size */ public static int sizeOfDataType(DataType dtype) { switch (dtype) { @@ -5666,7 +5500,7 @@ public class Nd4j { * * PLEASE NOTE: Do not use this method, unless you have too. * - * @param reallyEnable + * @param reallyEnable fallback mode */ public static void enableFallbackMode(boolean reallyEnable) { fallbackMode.set(reallyEnable); @@ -5675,8 +5509,9 @@ public class Nd4j { /** * This method checks, if fallback mode was enabled. * - * @return + * @return fallback mode */ + @SuppressWarnings("BooleanMethodIsAlwaysInverted") public static boolean isFallbackModeEnabled() { return fallbackMode.get(); } @@ -5684,31 +5519,29 @@ public class Nd4j { /** * This method returns WorkspaceManager implementation to be used within this JVM process * - * @return + * @return WorkspaceManager */ public static MemoryWorkspaceManager getWorkspaceManager() { return workspaceManager; } /** - * This method stacks vertically examples with the same shape, increasing result dimensionality. I.e. if you provide bunch of 3D tensors, output will be 4D tensor. Alignment is always applied to axis 0. + * This method stacks vertically examples with the same shape, increasing result dimensionality. + * I.e. if you provide bunch of 3D tensors, output will be 4D tensor. Alignment is always applied to axis 0. * - * @return + * @param arrays arrays to stack + * @return stacked arrays */ public static INDArray pile(@NonNull INDArray... arrays) { // if we have vectors as input, it's just vstack use case long[] shape = arrays[0].shape(); + //noinspection deprecation long[] newShape = ArrayUtils.add(shape, 0, 1); - boolean shouldReshape = true; - List reshaped = new ArrayList<>(); for(INDArray array: arrays) { - if (!shouldReshape) - reshaped.add(array); - else - reshaped.add(array.reshape(array.ordering(), newShape)); + reshaped.add(array.reshape(array.ordering(), newShape)); } return Nd4j.vstack(reshaped); @@ -5717,7 +5550,8 @@ public class Nd4j { /** * This method stacks vertically examples with the same shape, increasing result dimensionality. I.e. if you provide bunch of 3D tensors, output will be 4D tensor. Alignment is always applied to axis 0. * - * @return + * @param arrays arrays to stack + * @return stacked array */ public static INDArray pile(@NonNull Collection arrays) { return pile(arrays.toArray(new INDArray[0])); @@ -5726,22 +5560,20 @@ public class Nd4j { /** * This method does the opposite to pile/vstack/hstack - it returns independent TAD copies along given dimensions * - * @param tensor - * @param dimensions - * @return + * @param tensor Array to tear + * @param dimensions dimensions + * @return Array copies */ public static INDArray[] tear(INDArray tensor, @NonNull int... dimensions) { if (dimensions.length >= tensor.rank()) throw new ND4JIllegalStateException("Target dimensions number should be less tensor rank"); - for (int e = 0; e < dimensions.length; e++) - if (dimensions[e] < 0) - throw new ND4JIllegalStateException("Target dimensions can't have negative values"); + for (int dimension : dimensions) + if (dimension < 0) throw new ND4JIllegalStateException("Target dimensions can't have negative values"); return factory().tear(tensor, dimensions); } - /** * Upper triangle of an array. @@ -5750,50 +5582,34 @@ public class Nd4j { Please refer to the documentation for `tril` for further details. - * @param m - * @param k - * @return + See Also + -------- + tril : lower triangle of an array + + Examples + -------- + >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) + array([[ 1, 2, 3], + [ 4, 5, 6], + [ 0, 8, 9], + [ 0, 0, 12]]) + + """ + m = asanyarray(m) + mask = tri(*m.shape[-2:], k=k-1, dtype=bool) + + return where(mask, zeros(1, m.dtype), m) + + * @param m source array + * @param k to zero below the k-th diagonal + * @return copy with elements below the `k`-th diagonal zeroed. */ public static INDArray triu(INDArray m,int k) { - /** - * """ - Upper triangle of an array. - Return a copy of a matrix with the elements below the `k`-th diagonal - zeroed. - - Please refer to the documentation for `tril` for further details. - - See Also - -------- - tril : lower triangle of an array - - Examples - -------- - >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) - array([[ 1, 2, 3], - [ 4, 5, 6], - [ 0, 8, 9], - [ 0, 0, 12]]) - - """ - m = asanyarray(m) - mask = tri(*m.shape[-2:], k=k-1, dtype=bool) - - return where(mask, zeros(1, m.dtype), m) - */ - - //INDArray mask = tri(m.size(-2),1); - /** + /* * Find a way to apply choose with an existing condition array. * (This appears to be the select op in libnd4j) */ - /* - Select select = new Select(new INDArray[]{mask,Nd4j.zeros(1),m},new INDArray[]{Nd4j.zerosLike(m)}); - Nd4j.getExecutioner().exec(select); - return select.getOutputArgument(0); - */ - INDArray result = Nd4j.createUninitialized(m.shape()); val op = DynamicCustomOp.builder("triu") @@ -5803,25 +5619,18 @@ public class Nd4j { .build(); Nd4j.getExecutioner().execAndReturn(op); - return result; } - /** - * - * @param n - * @return + * See {@link #tri(int,int,int)} with m = n, k=0. */ public static INDArray tri(int n) { return tri(n,n,0); } /** - * - * @param n - * @param k - * @return + * See {@link #tri(int,int,int)} with m = n. */ public static INDArray tri(int n,int k) { return tri(n,n,k); @@ -5836,24 +5645,16 @@ public class Nd4j { * @param k The sub-diagonal at and below which the array is filled. `k` = 0 is the main diagonal, while `k` < 0 is below it, and `k` > 0 is above. The default is 0. - * @return + * @return array with ones at and below the given diagonal and zeros elsewhere */ public static INDArray tri(int n,int m,int k) { - /* - INDArray mRet = Transforms.greaterThanOrEqual(arange(n),arange(-k,m - k)); - - return mRet; - */ - INDArray ret = Nd4j.createUninitialized(n, m); - val op = DynamicCustomOp.builder("tri") .addIntegerArguments(n, m, k) .addOutputs(ret) .build(); Nd4j.getExecutioner().execAndReturn(op); - return ret; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 638cd8ac3..202a7e9a7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -7979,6 +7979,14 @@ public class Nd4jTestsC extends BaseNd4jTest { //from [4,4,3] to [2,4,6] then crop to [2,4,5] } + @Test + public void testToFromByteArray() throws IOException { + // simple test to get rid of toByteArray and fromByteArray compiler warnings. + INDArray x = Nd4j.arange(10); + byte[] xb = Nd4j.toByteArray(x); + INDArray y = Nd4j.fromByteArray(xb); + assertEquals(x,y); + } private static INDArray fwd(INDArray input, INDArray W, INDArray b){ INDArray ret = Nd4j.createUninitialized(input.size(0), W.size(1)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java index 1bf709a5f..729c18c77 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java @@ -102,7 +102,7 @@ public class RngTests extends BaseNd4jTest { } @Test - void testRandomBinomial() { + public void testRandomBinomial() { //silly tests. Just increasing the usage for randomBinomial to stop compiler warnings. INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3); assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial