From c29c011d1a74f9c0c6f67759a991b5734eb08790 Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Tue, 13 Aug 2019 10:30:45 +0900 Subject: [PATCH] Nd4j refactoring (#109) * refactoring. Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * review feedback Signed-off-by: Robert Altena --- .../java/org/nd4j/linalg/factory/Nd4j.java | 82 ++++--------------- 1 file changed, 18 insertions(+), 64 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index b82ce008a..9a28b9155 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -2222,7 +2222,6 @@ public class Nd4j { * Write an ndarray to a writer * @param writer the writer to write to * @param write the ndarray to write - * @throws IOException */ public static void write(OutputStream writer, INDArray write) throws IOException { DataOutputStream stream = new DataOutputStream(writer); @@ -2230,14 +2229,12 @@ public class Nd4j { stream.close(); } - /** * Convert an ndarray to a byte array * @param arr the array to convert * @return the converted byte array - * @throws IOException */ - public static byte[] toByteArray(INDArray arr) throws IOException { + public static byte[] toByteArray(@NonNull INDArray arr) throws IOException { if (arr.length() * arr.data().getElementSize() > Integer.MAX_VALUE) throw new ND4JIllegalStateException(""); @@ -2252,15 +2249,13 @@ public class Nd4j { * Read an ndarray from a byte array * @param arr the array to read from * @return the deserialized ndarray - * @throws IOException */ - public static INDArray fromByteArray(byte[] arr) throws IOException { + public static INDArray fromByteArray(@NonNull byte[] arr) throws IOException { ByteArrayInputStream bis = new ByteArrayInputStream(arr); INDArray ret = read(bis); return ret; } - /** * Read line via input streams * @@ -2280,7 +2275,6 @@ public class Nd4j { * @param split the split separator * @param charset the charset * @return the deserialized array. - * @throws IOException */ public static INDArray readNumpy(@NonNull DataType dataType, @NonNull InputStream filePath, @NonNull String split, @NonNull Charset charset) throws IOException { BufferedReader reader = new BufferedReader(new InputStreamReader(filePath, charset)); @@ -2449,7 +2443,7 @@ public class Nd4j { } //parse data if (lineNum > 5) { - String[] entries = line.replace("\\],", "").replaceAll("\\]", "").replaceAll("\\[", "").split(sep); + String[] entries = line.replace("\\],", "").replaceAll("]", "").replaceAll("\\[", "").split(sep); if (rank == 0) { try { newArr.addi((format.parse(entries[0])).doubleValue()); @@ -2496,7 +2490,6 @@ public class Nd4j { * @return NDArray */ public static INDArray readTxt(String filePath) { - String sep = ","; File file = new File(filePath); InputStream is = null; try { @@ -2505,13 +2498,7 @@ public class Nd4j { } catch (FileNotFoundException e) { throw new RuntimeException(e); } finally { - if (is != null) { - try { - is.close(); - } catch (IOException e) { - e.printStackTrace(); - } - } + IOUtils.closeQuietly(is); } } @@ -2552,9 +2539,9 @@ public class Nd4j { */ public static INDArray createArrayFromShapeBuffer(DataBuffer data, Pair shapeInfo) { int rank = Shape.rank(shapeInfo.getFirst()); - long offset = Shape.offset(shapeInfo.getFirst()); + // removed offset parameter that called a deprecated method which always returns 0. INDArray result = Nd4j.create(data, toIntArray(rank, Shape.shapeOf(shapeInfo.getFirst())), - toIntArray(rank, Shape.stride(shapeInfo.getFirst())), offset, Shape.order(shapeInfo.getFirst())); + toIntArray(rank, Shape.stride(shapeInfo.getFirst())), 0, Shape.order(shapeInfo.getFirst())); if (data instanceof CompressedDataBuffer) result.markAsCompressed(true); @@ -2566,7 +2553,6 @@ public class Nd4j { * * @param dis the data input stream to read from * @return the ndarray - * @throws IOException */ public static INDArray read(DataInputStream dis) throws IOException { val headerShape = BaseDataBuffer.readHeader(dis); @@ -2597,7 +2583,6 @@ public class Nd4j { * * @param arr the array to write * @param dataOutputStream the data output stream to write to - * @throws IOException */ public static void write(INDArray arr, DataOutputStream dataOutputStream) throws IOException { //BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here @@ -2614,7 +2599,6 @@ public class Nd4j { * Save an ndarray to the given file * @param arr the array to save * @param saveTo the file to save to - * @throws IOException */ public static void saveBinary(INDArray arr, File saveTo) throws IOException { BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(saveTo)); @@ -2629,7 +2613,6 @@ public class Nd4j { * Read a binary ndarray from the given file * @param read the nd array to read * @return the loaded ndarray - * @throws IOException */ public static INDArray readBinary(File read) throws IOException { BufferedInputStream bis = new BufferedInputStream(new FileInputStream(read)); @@ -2645,21 +2628,9 @@ public class Nd4j { * @param arr the array to clear */ public static void clearNans(INDArray arr) { - //BooleanIndexing.applyWhere(arr, Conditions.isNan(), new Value(Nd4j.EPS_THRESHOLD)); getExecutioner().exec(new ReplaceNans(arr, Nd4j.EPS_THRESHOLD)); } - - /** - * Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc - * - * @param reverse the matrix to reverse - * @return the reversed matrix - */ - public static INDArray rot(INDArray reverse) { - INDArray ret = INSTANCE.rot(reverse); - return ret; - } - + /** * Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc * @@ -2667,8 +2638,6 @@ public class Nd4j { * @return the reversed matrix */ public static INDArray reverse(INDArray reverse) { - //INDArray ret = INSTANCE.reverse(reverse); - //logCreationIfNecessary(ret); return Nd4j.getExecutioner().exec(new OldReverse(reverse)); } @@ -2682,8 +2651,7 @@ public class Nd4j { * @return the 1D range vector */ public static INDArray arange(double begin, double end, double step) { - INDArray ret = INSTANCE.arange(begin, end, step); - return ret; + return INSTANCE.arange(begin, end, step); } /** @@ -2693,8 +2661,7 @@ public class Nd4j { * See {@link #arange(double, double, double)} with step size 1. */ public static INDArray arange(double begin, double end) { - INDArray ret = INSTANCE.arange(begin, end, 1); - return ret; + return INSTANCE.arange(begin, end, 1); } /** @@ -2717,27 +2684,6 @@ public class Nd4j { INSTANCE.copy(a, b); } - /** - * Creates a new matrix where the values of the given vector are the diagonal values of - * the matrix if a vector is passed in, if a matrix is returns the kth diagonal - * in the matrix - * - * @param x the diagonal values - * @param k the kth diagonal to get - * @return new matrix - */ - public static INDArray diag(INDArray x, int k) { - INDArray ret; - if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) { - ret = Nd4j.create(new long[]{x.length(), x.length()}); - Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret})); - } else { - ret = Nd4j.createUninitialized(new long[]{Math.min(x.size(0), x.size(1))}); - Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret)); - } - return ret; - } - /** * Creates a new matrix where the values of the given vector are the diagonal values of * the matrix if a vector is passed in, if a matrix is returns the kth diagonal @@ -2747,7 +2693,15 @@ public class Nd4j { * @return new matrix */ public static INDArray diag(INDArray x) { - return diag(x, 0); + INDArray ret; + if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) { + ret = Nd4j.create(x.dataType(), x.length(), x.length()); + Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret})); + } else { + ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1))); + Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret)); + } + return ret; } /**