From f8615e0ef0c71c2a9493c20cc1dea925ef39951c Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Wed, 7 Aug 2019 19:31:48 +0900 Subject: [PATCH] nd4j refactoring. (#103) Restored writeTxt per slack DM. --- .../linalg/api/ndarray/NdArrayJSONReader.java | 144 -------- .../linalg/api/ndarray/NdArrayJSONWriter.java | 54 --- .../java/org/nd4j/linalg/factory/Nd4j.java | 330 +++++------------- 3 files changed, 85 insertions(+), 443 deletions(-) delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONReader.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONWriter.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONReader.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONReader.java deleted file mode 100644 index 74f2ad152..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONReader.java +++ /dev/null @@ -1,144 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ndarray; - -import org.apache.commons.io.FileUtils; -import org.apache.commons.io.LineIterator; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.File; -import java.io.IOException; - -/** - * Created by susaneraly on 6/16/16. - */ -@Deprecated -public class NdArrayJSONReader { - - public INDArray read(File jsonFile) { - INDArray result = this.loadNative(jsonFile); - if (result == null) { - //Must write support for parsing/normal json parsing - which will be inefficient - this.loadNonNative(jsonFile); - } - return result; - } - - private INDArray loadNative(File jsonFile) { - /* - We could dump an ndarray to a file with the tostring (since that is valid json) and use put/get to parse it as json - - But here we leverage our information of the tostring method to be more efficient - With our current toString format we use tads along dimension (rank-1,rank-2) to write to the array in two dimensional chunks at a time. - This is more efficient than setting each value at a time with putScalar. - This also means we can read the file one line at a time instead of loading the whole thing into memory - - Future work involves enhancing the write json method to provide more features to make the load more efficient - */ - int lineNum = 0; - int rowNum = 0; - int tensorNum = 0; - char theOrder = 'c'; - int[] theShape = {1, 1}; - int rank = 0; - double[][] subsetArr = {{0.0, 0.0}, {0.0, 0.0}}; - INDArray newArr = Nd4j.zeros(2, 2); - try { - LineIterator it = FileUtils.lineIterator(jsonFile); - try { - while (it.hasNext()) { - String line = it.nextLine(); - lineNum++; - line = line.replaceAll("\\s", ""); - if (line.equals("") || line.equals("}")) - continue; - // is it from dl4j? - if (lineNum == 2) { - String[] lineArr = line.split(":"); - String fileSource = lineArr[1].replaceAll("\\W", ""); - if (!fileSource.equals("dl4j")) - return null; - } - // parse ordering - if (lineNum == 3) { - String[] lineArr = line.split(":"); - theOrder = lineArr[1].replace("\\W", "").charAt(0); - continue; - } - // parse shape - if (lineNum == 4) { - String[] lineArr = line.split(":"); - String dropJsonComma = lineArr[1].split("]")[0]; - String[] shapeString = dropJsonComma.replace("[", "").split(","); - rank = shapeString.length; - theShape = new int[rank]; - for (int i = 0; i < rank; i++) { - try { - theShape[i] = Integer.parseInt(shapeString[i]); - } catch (NumberFormatException nfe) { - } ; - } - subsetArr = new double[theShape[rank - 2]][theShape[rank - 1]]; - newArr = Nd4j.zeros(theShape, theOrder); - continue; - } - //parse data - if (lineNum > 5) { - String[] entries = - line.replace("\\],", "").replaceAll("\\[", "").replaceAll("\\]", "").split(","); - for (int i = 0; i < theShape[rank - 1]; i++) { - try { - subsetArr[rowNum][i] = Double.parseDouble(entries[i]); - } catch (NumberFormatException nfe) { - } - } - rowNum++; - if (rowNum == theShape[rank - 2]) { - INDArray subTensor = Nd4j.create(subsetArr); - newArr.tensorAlongDimension(tensorNum, rank - 1, rank - 2).addi(subTensor); - rowNum = 0; - tensorNum++; - } - } - } - } finally { - LineIterator.closeQuietly(it); - } - } catch (IOException e) { - throw new RuntimeException("Error reading input", e); - } - return newArr; - } - - private INDArray loadNonNative(File jsonFile) { - /* WIP - JSONTokener tokener = new JSONTokener(new FileReader("test.json")); - JSONObject obj = new JSONObject(tokener); - JSONArray objArr = obj.optJSONArray("shape"); - int rank = objArr.length(); - int[] theShape = new int[rank]; - int rows = 1; - for (int i = 0; i < rank; ++i) { - theShape[i] = objArr.optInt(i); - if (i != objArr.length() - 1) - rows *= theShape[i]; - } - */ - System.out.println("API_Error: Current support only for files written from dl4j"); - return null; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONWriter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONWriter.java deleted file mode 100644 index 8d4cf86f8..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONWriter.java +++ /dev/null @@ -1,54 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ndarray; - -import org.apache.commons.io.FileUtils; - -import java.io.File; -import java.io.IOException; - -/** - * Created by susaneraly on 6/18/16. - */ -@Deprecated -public class NdArrayJSONWriter { - private NdArrayJSONWriter() {} - - /** - * - * @param thisnD - * @param filePath - */ - public static void write(INDArray thisnD, String filePath) { - //TO DO: Add precision support in toString - //TO DO: Write to file one line at time - String lineOne = "{\n"; - String lineTwo = "\"filefrom\": \"dl4j\",\n"; - String lineThree = "\"ordering\": \"" + thisnD.ordering() + "\",\n"; - String lineFour = "\"shape\":\t" + java.util.Arrays.toString(thisnD.shape()) + ",\n"; - String lineFive = "\"data\":\n"; - String fileData = thisnD.toString(); - String fileEnd = "\n}\n"; - - String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive; - try { - FileUtils.writeStringToFile(new File(filePath), fileBegin + fileData + fileEnd); - } catch (IOException e) { - throw new RuntimeException("Error writing output", e); - } - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 1317d67ed..b82ce008a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -75,7 +75,6 @@ import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.convolution.ConvolutionInstance; import org.nd4j.linalg.convolution.DefaultConvolutionInstance; import org.nd4j.linalg.env.EnvironmentalAction; -import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JUnknownDataTypeException; import org.nd4j.linalg.factory.Nd4jBackend.NoAvailableBackendException; @@ -1073,7 +1072,7 @@ public class Nd4j { return ret; } - protected static Indexer getIndexerByType(Pointer pointer, DataType dataType) { + private static Indexer getIndexerByType(Pointer pointer, DataType dataType) { switch (dataType) { case UINT64: case LONG: @@ -1113,40 +1112,7 @@ public class Nd4j { * @return the created buffer */ public static DataBuffer createBuffer(@NonNull Pointer pointer, long length, @NonNull DataType dataType) { - Pointer nPointer = null; - //TODO: remove dupplicate code. - switch (dataType) { - case LONG: - nPointer = new LongPointer(pointer); - break; - case INT: - nPointer = new IntPointer(pointer); - break; - case SHORT: - nPointer = new ShortPointer(pointer); - break; - case BYTE: - nPointer = new BytePointer(pointer); - break; - case UBYTE: - nPointer = new BytePointer(pointer); - break; - case BOOL: - nPointer = new BooleanPointer(pointer); - break; - case FLOAT: - nPointer = new FloatPointer(pointer); - break; - case HALF: - nPointer = new ShortPointer(pointer); - break; - case DOUBLE: - nPointer = new DoublePointer(pointer); - break; - default: - throw new UnsupportedOperationException("Unsupported data type: " + dataType); - } - + Pointer nPointer = getPointer(pointer, dataType); return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, dataType, length, getIndexerByType(nPointer, dataType)); } @@ -1161,7 +1127,12 @@ public class Nd4j { * @return the created buffer */ public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) { - Pointer nPointer = null; + Pointer nPointer = getPointer(pointer, dataType); + return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType)); + } + + private static Pointer getPointer(@NonNull Pointer pointer, @NonNull DataType dataType ){ + Pointer nPointer; switch (dataType) { case UINT64: case LONG: @@ -1198,7 +1169,7 @@ public class Nd4j { throw new UnsupportedOperationException("Unsupported data type: " + dataType); } - return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType)); + return nPointer; } /** @@ -1208,9 +1179,8 @@ public class Nd4j { * @return the created buffer */ public static DataBuffer createBuffer(float[] data, long offset) { - val ndata = Arrays.copyOfRange(data, (int) offset, data.length); - DataBuffer ret = createTypedBuffer(ndata, DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length), + DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -1220,9 +1190,8 @@ public class Nd4j { * @return the created buffer */ public static DataBuffer createBuffer(double[] data, long offset) { - val ndata = Arrays.copyOfRange(data, (int) offset, data.length); - DataBuffer ret = createTypedBuffer(ndata, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length), + DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -1294,7 +1263,7 @@ public class Nd4j { * * @param shape the shape of the buffer to create * @param type the opType to create - * @return + * @return the created buffer. */ public static DataBuffer createBufferDetached(int[] shape, DataType type) { return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type); @@ -1354,7 +1323,7 @@ public class Nd4j { * @param buffer the buffer to create from * @param type the opType of buffer to create * @param length the length of the buffer - * @return + * @return the created buffer */ public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) { switch (type) { @@ -1407,33 +1376,27 @@ public class Nd4j { * @return the created buffer */ public static DataBuffer createBuffer(long[] data) { - DataBuffer ret; - ret = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** * Create a buffer equal of length prod(shape). This method is NOT affected by workspaces * - * @param data - * @return + * @param data the shape of the buffer to create + * @return the created buffer */ public static DataBuffer createBufferDetached(int[] data) { - DataBuffer ret; - ret = DATA_BUFFER_FACTORY_INSTANCE.createInt(data); - return ret; + return DATA_BUFFER_FACTORY_INSTANCE.createInt(data); } /** * Create a buffer equal of length prod(shape). This method is NOT affected by workspaces * - * @param data - * @return + * @param data the shape of the buffer to create + * @return the created buffer */ public static DataBuffer createBufferDetached(long[] data) { - DataBuffer ret; - ret = DATA_BUFFER_FACTORY_INSTANCE.createLong(data); - return ret; + return DATA_BUFFER_FACTORY_INSTANCE.createLong(data); } /** @@ -1464,9 +1427,7 @@ public class Nd4j { * See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype. */ public static DataBuffer createBuffer(long length, boolean initialize) { - DataBuffer ret = createBuffer(Nd4j.dataType(), length, initialize); - - return ret; + return createBuffer(Nd4j.dataType(), length, initialize); } /** @@ -1524,6 +1485,11 @@ public class Nd4j { return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(data) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, Nd4j.getMemoryManager().getCurrentWorkspace()); } + // refactoring of duplicate code. + private static DataBuffer getDataBuffer(int length, DataType dataType){ + return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); + } + /** * Create a buffer based on the data of the underlying java array with the specified type.. * @param data underlying java array @@ -1531,16 +1497,16 @@ public class Nd4j { * @return created buffer, */ public static DataBuffer createTypedBuffer(double[] data, DataType dataType) { - val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); + DataBuffer buffer = getDataBuffer(data.length, dataType); buffer.setData(data); return buffer; } /** - * See {@link #createTypedBuffer(float[], DataType)} + * See {@link #createTypedBuffer(double[], DataType)} */ public static DataBuffer createTypedBuffer(float[] data, DataType dataType) { - val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); + DataBuffer buffer = getDataBuffer(data.length, dataType); buffer.setData(data); return buffer; } @@ -1549,7 +1515,7 @@ public class Nd4j { * See {@link #createTypedBuffer(float[], DataType)} */ public static DataBuffer createTypedBuffer(int[] data, DataType dataType) { - val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); + DataBuffer buffer = getDataBuffer(data.length, dataType); buffer.setData(data); return buffer; } @@ -1558,7 +1524,7 @@ public class Nd4j { * See {@link #createTypedBuffer(float[], DataType)} */ public static DataBuffer createTypedBuffer(long[] data, DataType dataType) { - val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); + DataBuffer buffer = getDataBuffer(data.length, dataType); buffer.setData(data); return buffer; } @@ -1567,7 +1533,7 @@ public class Nd4j { * See {@link #createTypedBuffer(float[], DataType)} */ public static DataBuffer createTypedBuffer(short[] data, DataType dataType) { - val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); + DataBuffer buffer = getDataBuffer(data.length, dataType); buffer.setData(data); return buffer; } @@ -1576,7 +1542,7 @@ public class Nd4j { * See {@link #createTypedBuffer(float[], DataType)} */ public static DataBuffer createTypedBuffer(byte[] data, DataType dataType) { - val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); + DataBuffer buffer = getDataBuffer(data.length, dataType); buffer.setData(data); return buffer; } @@ -1585,11 +1551,17 @@ public class Nd4j { * See {@link #createTypedBuffer(float[], DataType)} */ public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType) { - val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); + DataBuffer buffer = getDataBuffer(data.length, dataType); buffer.setData(data); return buffer; } + + // refactoring of duplicate code. + private static DataBuffer getDataBuffer(int length, DataType dataType, MemoryWorkspace workspace){ + return workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, workspace); + } + /** * Create a buffer based on the data of the underlying java array, specified type and workspace * @param data underlying java array @@ -1598,61 +1570,18 @@ public class Nd4j { * @return created buffer, */ public static DataBuffer createTypedBuffer(double[] data, DataType dataType, MemoryWorkspace workspace) { - val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); + //val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); + DataBuffer buffer = getDataBuffer(data.length, dataType, workspace); buffer.setData(data); return buffer; } /** - * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)} + * See {@link #createTypedBuffer(double[], DataType, MemoryWorkspace)} */ public static DataBuffer createTypedBuffer(float[] data, DataType dataType, MemoryWorkspace workspace) { - val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); - buffer.setData(data); - return buffer; - } - - /** - * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)} - */ - public static DataBuffer createTypedBuffer(int[] data, DataType dataType, MemoryWorkspace workspace) { - val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); - buffer.setData(data); - return buffer; - } - - /** - * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)} - */ - public static DataBuffer createTypedBuffer(long[] data, DataType dataType, MemoryWorkspace workspace) { - val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); - buffer.setData(data); - return buffer; - } - - /** - * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)} - */ - public static DataBuffer createTypedBuffer(short[] data, DataType dataType, MemoryWorkspace workspace) { - val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); - buffer.setData(data); - return buffer; - } - - /** - * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)} - */ - public static DataBuffer createTypedBuffer(byte[] data, DataType dataType, MemoryWorkspace workspace) { - val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); - buffer.setData(data); - return buffer; - } - - /** - * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)} - */ - public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType, MemoryWorkspace workspace) { - val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); + //val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); + DataBuffer buffer = getDataBuffer(data.length, dataType, workspace); buffer.setData(data); return buffer; } @@ -1661,7 +1590,7 @@ public class Nd4j { * Create am uninitialized buffer based on the data of the underlying java array and specified type. * @param data underlying java array * @param dataType specified type. - * @return + * @return the created buffer. */ public static DataBuffer createTypedBufferDetached(double[] data, DataType dataType) { val buffer = DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false); @@ -1731,14 +1660,6 @@ public class Nd4j { INSTANCE = factory; } - /** - * Set the factory instance for sparse INDArray creation. - * @param factory new INDArray factory - */ - public static void setSparseFactory(NDArrayFactory factory) { - SPARSE_INSTANCE = factory; - } - /** * Returns the ordering of the ndarrays * @@ -1810,15 +1731,6 @@ public class Nd4j { return SPARSE_BLAS_WRAPPER_INSTANCE; } - /** - * Sets the global blas wrapper - * - * @param factory - */ - public static void setBlasWrapper(BlasWrapper factory) { - BLAS_WRAPPER_INSTANCE = factory; - } - /** * Sort an ndarray along a particular dimension.
* Note that the input array is modified in-place. @@ -1899,7 +1811,7 @@ public class Nd4j { * * @param ndarray array to sort * @param ascending true for ascending, false for descending - * @return + * @return the sorted ndarray */ public static INDArray sort(INDArray ndarray, boolean ascending) { return getNDArrayFactory().sort(ndarray, !ascending); @@ -1931,20 +1843,18 @@ public class Nd4j { * @param in 2d array to sort * @param colIdx The column to sort on * @param ascending true if smallest-to-largest; false if largest-to-smallest - * @return + * @return the sorted ndarray */ + @SuppressWarnings("Duplicates") public static INDArray sortRows(final INDArray in, final int colIdx, final boolean ascending) { if (in.rank() != 2) throw new IllegalArgumentException("Cannot sort rows on non-2d matrix"); if (colIdx < 0 || colIdx >= in.columns()) throw new IllegalArgumentException("Cannot sort on values in column " + colIdx + ", nCols=" + in.columns()); - if (in.rows() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - INDArray out = Nd4j.create(in.dataType(), in.shape()); - int nRows = (int) in.rows(); - ArrayList list = new ArrayList(nRows); + int nRows = in.rows(); + ArrayList list = new ArrayList<>(nRows); for (int i = 0; i < nRows; i++) list.add(i); Collections.sort(list, new Comparator() { @@ -1976,19 +1886,17 @@ public class Nd4j { * @param in 2d array to sort * @param rowIdx The row to sort on * @param ascending true if smallest-to-largest; false if largest-to-smallest - * @return + * @return the sorted array. */ + @SuppressWarnings("Duplicates") public static INDArray sortColumns(final INDArray in, final int rowIdx, final boolean ascending) { if (in.rank() != 2) throw new IllegalArgumentException("Cannot sort columns on non-2d matrix"); if (rowIdx < 0 || rowIdx >= in.rows()) throw new IllegalArgumentException("Cannot sort on values in row " + rowIdx + ", nRows=" + in.rows()); - if (in.columns() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - INDArray out = Nd4j.create(in.shape()); - int nCols = (int) in.columns(); + int nCols = in.columns(); ArrayList list = new ArrayList<>(nCols); for (int i = 0; i < nCols; i++) list.add(i); @@ -2041,7 +1949,7 @@ public class Nd4j { if (dtype.isIntType()) { long upper = lower + num * step; - return linspaceWithCustomOpByRange((long) lower, upper, num, (long) step, dtype); + return linspaceWithCustomOpByRange( lower, upper, num, step, dtype); } else if (dtype.isFPType()) { return Nd4j.getExecutioner().exec(new Linspace((double) lower, num, (double)step, dtype)); } @@ -2062,7 +1970,6 @@ public class Nd4j { return linspace(lower, upper, num, Nd4j.dataType()); } - /** * Generate a linearly spaced vector * @@ -2077,7 +1984,7 @@ public class Nd4j { return Nd4j.scalar(dtype, lower); } if (dtype.isIntType()) { - return linspaceWithCustomOp((long)lower, (long)upper, (int)num, dtype); + return linspaceWithCustomOp(lower, upper, (int)num, dtype); } else if (dtype.isFPType()) { return linspace((double) lower, (double)upper, (int) num, dtype); } @@ -2182,8 +2089,7 @@ public class Nd4j { * these ndarrays */ public static INDArray toFlattened(Collection matrices) { - INDArray ret = INSTANCE.toFlattened(matrices); - return ret; + return INSTANCE.toFlattened(matrices); } /** @@ -2194,27 +2100,7 @@ public class Nd4j { * these ndarrays */ public static INDArray toFlattened(char order, Collection matrices) { - INDArray ret = INSTANCE.toFlattened(order, matrices); - return ret; - } - - /** - * Create a long row vector of all of the given ndarrays - * @param matrices the matrices to create the flattened ndarray for - * @return the flattened representation of - * these ndarrays - */ - public static INDArray toFlattened(int length, Iterator... matrices) { - INDArray ret = INSTANCE.toFlattened(length, matrices); - return ret; - } - - /** - * Returns a column vector where each entry is the nth bilinear - * product of the nth slices of the two tensors. - */ - public static INDArray bilinearProducts(INDArray curr, INDArray in) { - return INSTANCE.bilinearProducts(curr, in); + return INSTANCE.toFlattened(order, matrices); } /** @@ -2243,32 +2129,28 @@ public class Nd4j { * Create the identity ndarray * * @param n the number for the identity - * @return + * @return the identity array */ public static INDArray eye(long n) { - INDArray ret = INSTANCE.eye(n); - return ret; + return INSTANCE.eye(n); } /** * Rotate a matrix 90 degrees * * @param toRotate the matrix to rotate - * @return the rotated matrix */ public static void rot90(INDArray toRotate) { INSTANCE.rot90(toRotate); - } /** * Write NDArray to a text file * - * @param filePath + * @param filePath path to write to * @param split the split separator, defaults to "," - * @deprecated custom col separators are no longer supported; uses "," * @param precision digits after the decimal point - * @deprecated Precision is no longer used. + * @deprecated Precision is no longer used. Split is no longer used. * Defaults to scientific notation with 18 digits after the decimal * Use {@link #writeTxt(INDArray, String)} */ @@ -2279,10 +2161,10 @@ public class Nd4j { /** * Write NDArray to a text file * - * @param write - * @param filePath - * @param precision - * @deprecated Precision is no longer used. + * @param write array to write + * @param filePath path to write to + * @param precision Precision is no longer used. + * @deprecated * Defaults to scientific notation with 18 digits after the decimal * Use {@link #writeTxt(INDArray, String)} */ @@ -2293,9 +2175,9 @@ public class Nd4j { /** * Write NDArray to a text file * - * @param write - * @param filePath - * @param split + * @param write array to write + * @param filePath path to write to + * @param split the split separator, defaults to "," * @deprecated custom col and higher dimension separators are no longer supported; uses "," * Use {@link #writeTxt(INDArray, String)} */ @@ -2311,69 +2193,27 @@ public class Nd4j { */ public static void writeTxt(INDArray write, String filePath) { try { - String toWrite = writeStringForArray(write, "0.000000000000000000E0"); - FileUtils.writeStringToFile(new File(filePath), toWrite); + String toWrite = writeStringForArray(write); + FileUtils.writeStringToFile(new File(filePath), toWrite, (String)null, false); } catch (IOException e) { throw new RuntimeException("Error writing output", e); } } - /** - * @deprecated custom col separators are no longer supported; uses "," - * @deprecated precision can no longer be specified. The array is written in scientific notation. - * @see #writeTxtString(INDArray, OutputStream) - */ - public static void writeTxtString(INDArray write, OutputStream os, String split, int precision) { - writeTxtString(write,os); - } - - /** - * @deprecated precision can no longer be specified. The array is written in scientific notation. - * @see #writeTxtString(INDArray, OutputStream) - */ - @Deprecated - public static void writeTxtString(INDArray write, OutputStream os, int precision) { - writeTxtString(write,os); - } - - /** - * @deprecated column separator can longer be specified; Uses "," - * @see #writeTxtString(INDArray, OutputStream) - */ - @Deprecated - public static void writeTxtString(INDArray write, OutputStream os, String split) { - writeTxtString(write, os); - } - - /** - * Write ndarray as text to output stream - * @param write Array to write - * @param os stream to write too. - */ - public static void writeTxtString(INDArray write, OutputStream os) { - try { - // default format is "0.000000000000000000E0" - String toWrite = writeStringForArray(write, "0.000000000000000000E0"); - os.write(toWrite.getBytes()); - } catch (IOException e) { - throw new RuntimeException("Error writing output", e); - } - } - - private static String writeStringForArray(INDArray write, String format) { + private static String writeStringForArray(INDArray write) { if(write.isView() || !Shape.hasDefaultStridesForShape(write)) write = write.dup(); - if (format.isEmpty()) format = "0.000000000000000000E0"; - String lineOne = "{\n"; - String lineTwo = "\"filefrom\": \"dl4j\",\n"; - String lineThree = "\"ordering\": \"" + write.ordering() + "\",\n"; - String lineFour = "\"shape\":\t" + java.util.Arrays.toString(write.shape()) + ",\n"; - String lineFive = "\"data\":\n"; - String fileData = new NDArrayStrings(",", format).format(write, false); - String fileEnd = "\n}\n"; - String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive; - String fileContents = fileBegin + fileData + fileEnd; - return fileContents; + String format = "0.000000000000000000E0"; + + return new StringBuilder() + .append("{\n") + .append("\"filefrom\": \"dl4j\",\n") + .append( "\"ordering\": \"").append(write.ordering()).append("\",\n") + .append("\"shape\":\t").append( java.util.Arrays.toString(write.shape())).append(",\n") + .append("\"data\":\n") + .append(new NDArrayStrings(",", format).format(write, false)) + .append("\n}\n") + .toString(); }