Nd4j refactoring (last one!) (#123)

* fix: IOException no longer thrown by read().

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* refactoring

* last refactorings

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-16 15:39:11 +09:00 committed by Alex Black
parent 5c908886b0
commit 7fbc4b0933
1 changed files with 67 additions and 72 deletions

View File

@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@ -5673,7 +5673,6 @@ public class Nd4j {
public static INDArray[] where(INDArray condition, INDArray x, INDArray y){
Preconditions.checkState((x == null && y == null) || (x != null && y != null), "Both X and Y must be" +
"null, or neither must be null");
INDArray out;
DynamicCustomOp.DynamicCustomOpsBuilder op = DynamicCustomOp.builder("where_np");
List<LongShapeDescriptor> outShapes;
if(x == null){
@ -5681,6 +5680,7 @@ public class Nd4j {
op.addInputs(condition);
} else {
if(!x.equalShapes(y) || !x.equalShapes(condition)){
//noinspection ConstantConditions
Preconditions.throwStateEx("Shapes must be equal: condition=%s, x=%s, y=%s", condition.shape(), x.shape(), y.shape());
}
op.addInputs(condition, x, y);
@ -5713,6 +5713,7 @@ public class Nd4j {
* @param file the file to write to
* @throws IOException if an error occurs when writing the file
*/
@SuppressWarnings("WeakerAccess")
public static void writeAsNumpy(INDArray arr, File file) throws IOException {
writeAsNumpy(arr, new FileOutputStream(file));
}
@ -5723,6 +5724,7 @@ public class Nd4j {
* @param arr the array to convert
* @return a pointer to the numpy struct
*/
@SuppressWarnings("WeakerAccess")
public static Pointer convertToNumpy(INDArray arr) {
return INSTANCE.convertToNumpy(arr);
}
@ -5732,8 +5734,8 @@ public class Nd4j {
* Writes an array to an output stream
* @param arr the array to write
* @param writeTo the output stream to write to
* @throws IOException
*/
@SuppressWarnings("WeakerAccess")
public static void writeAsNumpy(INDArray arr, OutputStream writeTo) throws IOException {
try(BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(writeTo)) {
Pointer asNumpy = convertToNumpy(arr);
@ -5746,7 +5748,6 @@ public class Nd4j {
bufferedOutputStream.flush();
}
}
@ -5759,6 +5760,7 @@ public class Nd4j {
* numpy pointer
*/
@SuppressWarnings("WeakerAccess")
public static INDArray createFromNpyPointer(Pointer pointer) {
return INSTANCE.createFromNpyPointer(pointer);
}
@ -5784,8 +5786,8 @@ public class Nd4j {
* Create a numpy array based on the passed in input stream
* @param is the input stream to read
* @return the loaded ndarray
* @throws IOException
*/
@SuppressWarnings("unused")
public static INDArray createNpyFromInputStream(InputStream is) throws IOException {
byte[] content = IOUtils.toByteArray(is);
return createNpyFromByteArray(content);
@ -5814,7 +5816,6 @@ public class Nd4j {
* @return the {@link INDArray} as a byte array
* with the numpy format.
* For more on the format, see: https://docs.scipy.org/doc/numpy-1.14.0/neps/npy-format.html
* @throws IOException
*/
public static byte[] toNpyByteArray(INDArray input) {
try {
@ -5907,7 +5908,7 @@ public class Nd4j {
val bytes = new byte[prod];
val sb = bb.order(_order).asReadOnlyBuffer();
for (int e = 0; e < prod; e++)
bytes[e] = (byte) sb.get(e + sb.position());
bytes[e] = sb.get(e + sb.position());
return Nd4j.create(bytes, shapeOf, stridesOf, ordering, DataType.BYTE);
}
@ -5951,7 +5952,7 @@ public class Nd4j {
/**
* This method returns maximal allowed number of threads for Nd4j.
* If value wasn't set in advance, max(1, availableProcessor) will be returned
* @return
* @return maximal allowed number of threads
*/
public static int numThreads() {
val v = numThreads.get();
@ -5963,7 +5964,7 @@ public class Nd4j {
/**
* This method sets maximal allowed number of threads for Nd4j
* @param numthreads
* @param numthreads maximal allowed number of threads
*/
public static void setNumThreads(int numthreads) {
numThreads.set(numthreads);
@ -5979,6 +5980,7 @@ public class Nd4j {
public static INDArray scalar(@NonNull String string) {
//noinspection RedundantArrayCreation
return create(Collections.singletonList(string), new long[0]);
}
@ -5998,7 +6000,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with DOUBLE data type
*/
public static INDArray createFromArray(double... array) {
@ -6011,7 +6013,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with FLOAT data type
*/
public static INDArray createFromArray(float... array) {
@ -6024,7 +6026,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with INT32 data type
*/
public static INDArray createFromArray(int... array) {
@ -6037,7 +6039,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with INT16 data type
*/
public static INDArray createFromArray(short... array) {
@ -6050,7 +6052,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with INT8 data type
*/
public static INDArray createFromArray(byte... array) {
@ -6063,7 +6065,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with INT64 data type
*/
public static INDArray createFromArray(long... array) {
@ -6076,7 +6078,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with BOOL data type
*/
public static INDArray createFromArray(boolean... array) {
@ -6091,7 +6093,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with DOUBLE data type
*/
public static INDArray createFromArray(double[][] array) {
@ -6105,7 +6107,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with FLOAT data type
*/
public static INDArray createFromArray(float[][] array) {
@ -6119,7 +6121,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with INT64 data type
*/
public static INDArray createFromArray(long[][] array) {
@ -6133,7 +6135,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with INT32 data type
*/
public static INDArray createFromArray(int[][] array) {
@ -6147,7 +6149,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with INT16 data type
*/
public static INDArray createFromArray(short[][] array) {
@ -6161,7 +6163,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with INT8 data type
*/
public static INDArray createFromArray(byte[][] array) {
@ -6175,7 +6177,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with BOOL data type
*/
public static INDArray createFromArray(boolean[][] array) {
@ -6192,7 +6194,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with DOUBLE data type
*/
public static INDArray createFromArray(double[][][] array) {
@ -6206,7 +6208,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with FLOAT data type
*/
public static INDArray createFromArray(float[][][] array) {
@ -6220,7 +6222,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with INT64 data type
*/
public static INDArray createFromArray(long[][][] array) {
@ -6235,7 +6237,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with INT32 data type
*/
public static INDArray createFromArray(int[][][] array) {
@ -6250,7 +6252,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with INT16 data type
*/
public static INDArray createFromArray(short[][][] array) {
@ -6264,7 +6266,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with INT8 data type
*/
public static INDArray createFromArray(byte[][][] array) {
@ -6278,7 +6280,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with BOOL data type
*/
public static INDArray createFromArray(boolean[][][] array) {
@ -6294,7 +6296,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with DOUBLE data type
*/
public static INDArray createFromArray(double[][][][] array) {
@ -6308,7 +6310,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with FLOAT data type
*/
public static INDArray createFromArray(float[][][][] array) {
@ -6322,7 +6324,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with INT64 data type
*/
public static INDArray createFromArray(long[][][][] array) {
@ -6336,7 +6338,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with INT32 data type
*/
public static INDArray createFromArray(int[][][][] array) {
@ -6350,7 +6352,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with INT16 data type
*/
public static INDArray createFromArray(short[][][][] array) {
@ -6364,7 +6366,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with INT8 data type
*/
public static INDArray createFromArray(byte[][][][] array) {
@ -6378,7 +6380,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with BOOL data type
*/
public static INDArray createFromArray(boolean[][][][] array) {
@ -6390,7 +6392,6 @@ public class Nd4j {
return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.BOOL);
}
public static synchronized DeallocatorService getDeallocatorService() {
if (deallocatorService == null)
deallocatorService = new DeallocatorService();
@ -6402,7 +6403,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with DOUBLE data type
*/
public static INDArray createFromArray(Double[] array) {
@ -6411,7 +6412,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with FLOAT data type
*/
public static INDArray createFromArray(Float[] array) {
@ -6420,7 +6421,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with INT32 data type
*/
public static INDArray createFromArray(Integer[] array) {
@ -6429,7 +6430,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with INT16 data type
*/
public static INDArray createFromArray(Short[] array) {
@ -6438,7 +6439,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with INT8 data type
*/
public static INDArray createFromArray(Byte[] array) {
@ -6447,7 +6448,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with INT64 data type
*/
public static INDArray createFromArray(Long[] array) {
@ -6456,7 +6457,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 1D INDArray with BOOL data type
*/
public static INDArray createFromArray(Boolean[] array) {
@ -6467,7 +6468,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with DOUBLE data type
*/
public static INDArray createFromArray(Double[][] array) {
@ -6476,7 +6477,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with FLOAT data type
*/
public static INDArray createFromArray(Float[][] array) {
@ -6485,7 +6486,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with INT32 data type
*/
public static INDArray createFromArray(Integer[][] array) {
@ -6494,7 +6495,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with INT16 data type
*/
public static INDArray createFromArray(Short[][] array) {
@ -6503,7 +6504,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with INT8 data type
*/
public static INDArray createFromArray(Byte[][] array) {
@ -6512,7 +6513,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with INT64 data type
*/
public static INDArray createFromArray(Long[][] array) {
@ -6521,7 +6522,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 2D INDArray with BOOL data type
*/
public static INDArray createFromArray(Boolean[][] array) {
@ -6532,7 +6533,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with DOUBLE data type
*/
public static INDArray createFromArray(Double[][][] array) {
@ -6541,7 +6542,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with FLOAT data type
*/
public static INDArray createFromArray(Float[][][] array) {
@ -6550,7 +6551,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with INT32 data type
*/
public static INDArray createFromArray(Integer[][][] array) {
@ -6559,7 +6560,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with INT16 data type
*/
public static INDArray createFromArray(Short[][][] array) {
@ -6568,7 +6569,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with INT8 data type
*/
public static INDArray createFromArray(Byte[][][] array) {
@ -6577,7 +6578,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with INT64 data type
*/
public static INDArray createFromArray(Long[][][] array) {
@ -6586,7 +6587,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 3D INDArray with BOOL data type
*/
public static INDArray createFromArray(Boolean[][][] array) {
@ -6597,7 +6598,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with DOUBLE data type
*/
public static INDArray createFromArray(Double[][][][] array) {
@ -6606,7 +6607,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with FLOAT data type
*/
public static INDArray createFromArray(Float[][][][] array) {
@ -6615,7 +6616,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with INT32 data type
*/
public static INDArray createFromArray(Integer[][][][] array) {
@ -6624,7 +6625,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with INT16 data type
*/
public static INDArray createFromArray(Short[][][][] array) {
@ -6633,7 +6634,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with INT8 data type
*/
public static INDArray createFromArray(Byte[][][][] array) {
@ -6642,7 +6643,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with INT64 data type
*/
public static INDArray createFromArray(Long[][][][] array) {
@ -6651,7 +6652,7 @@ public class Nd4j {
/**
* This method creates INDArray from provided jvm array
* @param array
* @param array jvm array
* @return 4D INDArray with BOOL data type
*/
public static INDArray createFromArray(Boolean[][][][] array) {
@ -6692,12 +6693,6 @@ public class Nd4j {
/**
* This method applies ScatterUpdate op
*
* @param op
* @param array
* @param indices
* @param updates
* @param axis
*/
@Deprecated
public static void scatterUpdate(ScatterUpdate.UpdateOp op, @NonNull INDArray array, @NonNull INDArray indices, @NonNull INDArray updates, int... axis) {