diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONReader.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONReader.java
deleted file mode 100644
index 74f2ad152..000000000
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONReader.java
+++ /dev/null
@@ -1,144 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ndarray;
-
-import org.apache.commons.io.FileUtils;
-import org.apache.commons.io.LineIterator;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.io.File;
-import java.io.IOException;
-
-/**
- * Created by susaneraly on 6/16/16.
- */
-@Deprecated
-public class NdArrayJSONReader {
-
- public INDArray read(File jsonFile) {
- INDArray result = this.loadNative(jsonFile);
- if (result == null) {
- //Must write support for parsing/normal json parsing - which will be inefficient
- this.loadNonNative(jsonFile);
- }
- return result;
- }
-
- private INDArray loadNative(File jsonFile) {
- /*
- We could dump an ndarray to a file with the tostring (since that is valid json) and use put/get to parse it as json
-
- But here we leverage our information of the tostring method to be more efficient
- With our current toString format we use tads along dimension (rank-1,rank-2) to write to the array in two dimensional chunks at a time.
- This is more efficient than setting each value at a time with putScalar.
- This also means we can read the file one line at a time instead of loading the whole thing into memory
-
- Future work involves enhancing the write json method to provide more features to make the load more efficient
- */
- int lineNum = 0;
- int rowNum = 0;
- int tensorNum = 0;
- char theOrder = 'c';
- int[] theShape = {1, 1};
- int rank = 0;
- double[][] subsetArr = {{0.0, 0.0}, {0.0, 0.0}};
- INDArray newArr = Nd4j.zeros(2, 2);
- try {
- LineIterator it = FileUtils.lineIterator(jsonFile);
- try {
- while (it.hasNext()) {
- String line = it.nextLine();
- lineNum++;
- line = line.replaceAll("\\s", "");
- if (line.equals("") || line.equals("}"))
- continue;
- // is it from dl4j?
- if (lineNum == 2) {
- String[] lineArr = line.split(":");
- String fileSource = lineArr[1].replaceAll("\\W", "");
- if (!fileSource.equals("dl4j"))
- return null;
- }
- // parse ordering
- if (lineNum == 3) {
- String[] lineArr = line.split(":");
- theOrder = lineArr[1].replace("\\W", "").charAt(0);
- continue;
- }
- // parse shape
- if (lineNum == 4) {
- String[] lineArr = line.split(":");
- String dropJsonComma = lineArr[1].split("]")[0];
- String[] shapeString = dropJsonComma.replace("[", "").split(",");
- rank = shapeString.length;
- theShape = new int[rank];
- for (int i = 0; i < rank; i++) {
- try {
- theShape[i] = Integer.parseInt(shapeString[i]);
- } catch (NumberFormatException nfe) {
- } ;
- }
- subsetArr = new double[theShape[rank - 2]][theShape[rank - 1]];
- newArr = Nd4j.zeros(theShape, theOrder);
- continue;
- }
- //parse data
- if (lineNum > 5) {
- String[] entries =
- line.replace("\\],", "").replaceAll("\\[", "").replaceAll("\\]", "").split(",");
- for (int i = 0; i < theShape[rank - 1]; i++) {
- try {
- subsetArr[rowNum][i] = Double.parseDouble(entries[i]);
- } catch (NumberFormatException nfe) {
- }
- }
- rowNum++;
- if (rowNum == theShape[rank - 2]) {
- INDArray subTensor = Nd4j.create(subsetArr);
- newArr.tensorAlongDimension(tensorNum, rank - 1, rank - 2).addi(subTensor);
- rowNum = 0;
- tensorNum++;
- }
- }
- }
- } finally {
- LineIterator.closeQuietly(it);
- }
- } catch (IOException e) {
- throw new RuntimeException("Error reading input", e);
- }
- return newArr;
- }
-
- private INDArray loadNonNative(File jsonFile) {
- /* WIP
- JSONTokener tokener = new JSONTokener(new FileReader("test.json"));
- JSONObject obj = new JSONObject(tokener);
- JSONArray objArr = obj.optJSONArray("shape");
- int rank = objArr.length();
- int[] theShape = new int[rank];
- int rows = 1;
- for (int i = 0; i < rank; ++i) {
- theShape[i] = objArr.optInt(i);
- if (i != objArr.length() - 1)
- rows *= theShape[i];
- }
- */
- System.out.println("API_Error: Current support only for files written from dl4j");
- return null;
- }
-}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONWriter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONWriter.java
deleted file mode 100644
index 8d4cf86f8..000000000
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/NdArrayJSONWriter.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ndarray;
-
-import org.apache.commons.io.FileUtils;
-
-import java.io.File;
-import java.io.IOException;
-
-/**
- * Created by susaneraly on 6/18/16.
- */
-@Deprecated
-public class NdArrayJSONWriter {
- private NdArrayJSONWriter() {}
-
- /**
- *
- * @param thisnD
- * @param filePath
- */
- public static void write(INDArray thisnD, String filePath) {
- //TO DO: Add precision support in toString
- //TO DO: Write to file one line at time
- String lineOne = "{\n";
- String lineTwo = "\"filefrom\": \"dl4j\",\n";
- String lineThree = "\"ordering\": \"" + thisnD.ordering() + "\",\n";
- String lineFour = "\"shape\":\t" + java.util.Arrays.toString(thisnD.shape()) + ",\n";
- String lineFive = "\"data\":\n";
- String fileData = thisnD.toString();
- String fileEnd = "\n}\n";
-
- String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive;
- try {
- FileUtils.writeStringToFile(new File(filePath), fileBegin + fileData + fileEnd);
- } catch (IOException e) {
- throw new RuntimeException("Error writing output", e);
- }
- }
-}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
index 1317d67ed..b82ce008a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
@@ -75,7 +75,6 @@ import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.convolution.ConvolutionInstance;
import org.nd4j.linalg.convolution.DefaultConvolutionInstance;
import org.nd4j.linalg.env.EnvironmentalAction;
-import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JUnknownDataTypeException;
import org.nd4j.linalg.factory.Nd4jBackend.NoAvailableBackendException;
@@ -1073,7 +1072,7 @@ public class Nd4j {
return ret;
}
- protected static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
+ private static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
switch (dataType) {
case UINT64:
case LONG:
@@ -1113,40 +1112,7 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(@NonNull Pointer pointer, long length, @NonNull DataType dataType) {
- Pointer nPointer = null;
- //TODO: remove dupplicate code.
- switch (dataType) {
- case LONG:
- nPointer = new LongPointer(pointer);
- break;
- case INT:
- nPointer = new IntPointer(pointer);
- break;
- case SHORT:
- nPointer = new ShortPointer(pointer);
- break;
- case BYTE:
- nPointer = new BytePointer(pointer);
- break;
- case UBYTE:
- nPointer = new BytePointer(pointer);
- break;
- case BOOL:
- nPointer = new BooleanPointer(pointer);
- break;
- case FLOAT:
- nPointer = new FloatPointer(pointer);
- break;
- case HALF:
- nPointer = new ShortPointer(pointer);
- break;
- case DOUBLE:
- nPointer = new DoublePointer(pointer);
- break;
- default:
- throw new UnsupportedOperationException("Unsupported data type: " + dataType);
- }
-
+ Pointer nPointer = getPointer(pointer, dataType);
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, dataType, length, getIndexerByType(nPointer, dataType));
}
@@ -1161,7 +1127,12 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) {
- Pointer nPointer = null;
+ Pointer nPointer = getPointer(pointer, dataType);
+ return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType));
+ }
+
+ private static Pointer getPointer(@NonNull Pointer pointer, @NonNull DataType dataType ){
+ Pointer nPointer;
switch (dataType) {
case UINT64:
case LONG:
@@ -1198,7 +1169,7 @@ public class Nd4j {
throw new UnsupportedOperationException("Unsupported data type: " + dataType);
}
- return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType));
+ return nPointer;
}
/**
@@ -1208,9 +1179,8 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(float[] data, long offset) {
- val ndata = Arrays.copyOfRange(data, (int) offset, data.length);
- DataBuffer ret = createTypedBuffer(ndata, DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
- return ret;
+ return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length),
+ DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
@@ -1220,9 +1190,8 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(double[] data, long offset) {
- val ndata = Arrays.copyOfRange(data, (int) offset, data.length);
- DataBuffer ret = createTypedBuffer(ndata, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
- return ret;
+ return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length),
+ DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
@@ -1294,7 +1263,7 @@ public class Nd4j {
*
* @param shape the shape of the buffer to create
* @param type the opType to create
- * @return
+ * @return the created buffer.
*/
public static DataBuffer createBufferDetached(int[] shape, DataType type) {
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type);
@@ -1354,7 +1323,7 @@ public class Nd4j {
* @param buffer the buffer to create from
* @param type the opType of buffer to create
* @param length the length of the buffer
- * @return
+ * @return the created buffer
*/
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) {
switch (type) {
@@ -1407,33 +1376,27 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(long[] data) {
- DataBuffer ret;
- ret = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace());
- return ret;
+ return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
* Create a buffer equal of length prod(shape). This method is NOT affected by workspaces
*
- * @param data
- * @return
+ * @param data the shape of the buffer to create
+ * @return the created buffer
*/
public static DataBuffer createBufferDetached(int[] data) {
- DataBuffer ret;
- ret = DATA_BUFFER_FACTORY_INSTANCE.createInt(data);
- return ret;
+ return DATA_BUFFER_FACTORY_INSTANCE.createInt(data);
}
/**
* Create a buffer equal of length prod(shape). This method is NOT affected by workspaces
*
- * @param data
- * @return
+ * @param data the shape of the buffer to create
+ * @return the created buffer
*/
public static DataBuffer createBufferDetached(long[] data) {
- DataBuffer ret;
- ret = DATA_BUFFER_FACTORY_INSTANCE.createLong(data);
- return ret;
+ return DATA_BUFFER_FACTORY_INSTANCE.createLong(data);
}
/**
@@ -1464,9 +1427,7 @@ public class Nd4j {
* See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype.
*/
public static DataBuffer createBuffer(long length, boolean initialize) {
- DataBuffer ret = createBuffer(Nd4j.dataType(), length, initialize);
-
- return ret;
+ return createBuffer(Nd4j.dataType(), length, initialize);
}
/**
@@ -1524,6 +1485,11 @@ public class Nd4j {
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(data) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, Nd4j.getMemoryManager().getCurrentWorkspace());
}
+ // refactoring of duplicate code.
+ private static DataBuffer getDataBuffer(int length, DataType dataType){
+ return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
+ }
+
/**
* Create a buffer based on the data of the underlying java array with the specified type..
* @param data underlying java array
@@ -1531,16 +1497,16 @@ public class Nd4j {
* @return created buffer,
*/
public static DataBuffer createTypedBuffer(double[] data, DataType dataType) {
- val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
+ DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
/**
- * See {@link #createTypedBuffer(float[], DataType)}
+ * See {@link #createTypedBuffer(double[], DataType)}
*/
public static DataBuffer createTypedBuffer(float[] data, DataType dataType) {
- val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
+ DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@@ -1549,7 +1515,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(int[] data, DataType dataType) {
- val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
+ DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@@ -1558,7 +1524,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(long[] data, DataType dataType) {
- val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
+ DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@@ -1567,7 +1533,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(short[] data, DataType dataType) {
- val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
+ DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@@ -1576,7 +1542,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(byte[] data, DataType dataType) {
- val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
+ DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@@ -1585,11 +1551,17 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType) {
- val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
+ DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
+
+ // refactoring of duplicate code.
+ private static DataBuffer getDataBuffer(int length, DataType dataType, MemoryWorkspace workspace){
+ return workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, workspace);
+ }
+
/**
* Create a buffer based on the data of the underlying java array, specified type and workspace
* @param data underlying java array
@@ -1598,61 +1570,18 @@ public class Nd4j {
* @return created buffer,
*/
public static DataBuffer createTypedBuffer(double[] data, DataType dataType, MemoryWorkspace workspace) {
- val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
+ //val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
+ DataBuffer buffer = getDataBuffer(data.length, dataType, workspace);
buffer.setData(data);
return buffer;
}
/**
- * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
+ * See {@link #createTypedBuffer(double[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(float[] data, DataType dataType, MemoryWorkspace workspace) {
- val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
- buffer.setData(data);
- return buffer;
- }
-
- /**
- * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
- */
- public static DataBuffer createTypedBuffer(int[] data, DataType dataType, MemoryWorkspace workspace) {
- val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
- buffer.setData(data);
- return buffer;
- }
-
- /**
- * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
- */
- public static DataBuffer createTypedBuffer(long[] data, DataType dataType, MemoryWorkspace workspace) {
- val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
- buffer.setData(data);
- return buffer;
- }
-
- /**
- * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
- */
- public static DataBuffer createTypedBuffer(short[] data, DataType dataType, MemoryWorkspace workspace) {
- val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
- buffer.setData(data);
- return buffer;
- }
-
- /**
- * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
- */
- public static DataBuffer createTypedBuffer(byte[] data, DataType dataType, MemoryWorkspace workspace) {
- val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
- buffer.setData(data);
- return buffer;
- }
-
- /**
- * See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
- */
- public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType, MemoryWorkspace workspace) {
- val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
+ //val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
+ DataBuffer buffer = getDataBuffer(data.length, dataType, workspace);
buffer.setData(data);
return buffer;
}
@@ -1661,7 +1590,7 @@ public class Nd4j {
* Create am uninitialized buffer based on the data of the underlying java array and specified type.
* @param data underlying java array
* @param dataType specified type.
- * @return
+ * @return the created buffer.
*/
public static DataBuffer createTypedBufferDetached(double[] data, DataType dataType) {
val buffer = DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false);
@@ -1731,14 +1660,6 @@ public class Nd4j {
INSTANCE = factory;
}
- /**
- * Set the factory instance for sparse INDArray creation.
- * @param factory new INDArray factory
- */
- public static void setSparseFactory(NDArrayFactory factory) {
- SPARSE_INSTANCE = factory;
- }
-
/**
* Returns the ordering of the ndarrays
*
@@ -1810,15 +1731,6 @@ public class Nd4j {
return SPARSE_BLAS_WRAPPER_INSTANCE;
}
- /**
- * Sets the global blas wrapper
- *
- * @param factory
- */
- public static void setBlasWrapper(BlasWrapper factory) {
- BLAS_WRAPPER_INSTANCE = factory;
- }
-
/**
* Sort an ndarray along a particular dimension.
* Note that the input array is modified in-place.
@@ -1899,7 +1811,7 @@ public class Nd4j {
*
* @param ndarray array to sort
* @param ascending true for ascending, false for descending
- * @return
+ * @return the sorted ndarray
*/
public static INDArray sort(INDArray ndarray, boolean ascending) {
return getNDArrayFactory().sort(ndarray, !ascending);
@@ -1931,20 +1843,18 @@ public class Nd4j {
* @param in 2d array to sort
* @param colIdx The column to sort on
* @param ascending true if smallest-to-largest; false if largest-to-smallest
- * @return
+ * @return the sorted ndarray
*/
+ @SuppressWarnings("Duplicates")
public static INDArray sortRows(final INDArray in, final int colIdx, final boolean ascending) {
if (in.rank() != 2)
throw new IllegalArgumentException("Cannot sort rows on non-2d matrix");
if (colIdx < 0 || colIdx >= in.columns())
throw new IllegalArgumentException("Cannot sort on values in column " + colIdx + ", nCols=" + in.columns());
- if (in.rows() > Integer.MAX_VALUE)
- throw new ND4JArraySizeException();
-
INDArray out = Nd4j.create(in.dataType(), in.shape());
- int nRows = (int) in.rows();
- ArrayList list = new ArrayList(nRows);
+ int nRows = in.rows();
+ ArrayList list = new ArrayList<>(nRows);
for (int i = 0; i < nRows; i++)
list.add(i);
Collections.sort(list, new Comparator() {
@@ -1976,19 +1886,17 @@ public class Nd4j {
* @param in 2d array to sort
* @param rowIdx The row to sort on
* @param ascending true if smallest-to-largest; false if largest-to-smallest
- * @return
+ * @return the sorted array.
*/
+ @SuppressWarnings("Duplicates")
public static INDArray sortColumns(final INDArray in, final int rowIdx, final boolean ascending) {
if (in.rank() != 2)
throw new IllegalArgumentException("Cannot sort columns on non-2d matrix");
if (rowIdx < 0 || rowIdx >= in.rows())
throw new IllegalArgumentException("Cannot sort on values in row " + rowIdx + ", nRows=" + in.rows());
- if (in.columns() > Integer.MAX_VALUE)
- throw new ND4JArraySizeException();
-
INDArray out = Nd4j.create(in.shape());
- int nCols = (int) in.columns();
+ int nCols = in.columns();
ArrayList list = new ArrayList<>(nCols);
for (int i = 0; i < nCols; i++)
list.add(i);
@@ -2041,7 +1949,7 @@ public class Nd4j {
if (dtype.isIntType()) {
long upper = lower + num * step;
- return linspaceWithCustomOpByRange((long) lower, upper, num, (long) step, dtype);
+ return linspaceWithCustomOpByRange( lower, upper, num, step, dtype);
} else if (dtype.isFPType()) {
return Nd4j.getExecutioner().exec(new Linspace((double) lower, num, (double)step, dtype));
}
@@ -2062,7 +1970,6 @@ public class Nd4j {
return linspace(lower, upper, num, Nd4j.dataType());
}
-
/**
* Generate a linearly spaced vector
*
@@ -2077,7 +1984,7 @@ public class Nd4j {
return Nd4j.scalar(dtype, lower);
}
if (dtype.isIntType()) {
- return linspaceWithCustomOp((long)lower, (long)upper, (int)num, dtype);
+ return linspaceWithCustomOp(lower, upper, (int)num, dtype);
} else if (dtype.isFPType()) {
return linspace((double) lower, (double)upper, (int) num, dtype);
}
@@ -2182,8 +2089,7 @@ public class Nd4j {
* these ndarrays
*/
public static INDArray toFlattened(Collection matrices) {
- INDArray ret = INSTANCE.toFlattened(matrices);
- return ret;
+ return INSTANCE.toFlattened(matrices);
}
/**
@@ -2194,27 +2100,7 @@ public class Nd4j {
* these ndarrays
*/
public static INDArray toFlattened(char order, Collection matrices) {
- INDArray ret = INSTANCE.toFlattened(order, matrices);
- return ret;
- }
-
- /**
- * Create a long row vector of all of the given ndarrays
- * @param matrices the matrices to create the flattened ndarray for
- * @return the flattened representation of
- * these ndarrays
- */
- public static INDArray toFlattened(int length, Iterator extends INDArray>... matrices) {
- INDArray ret = INSTANCE.toFlattened(length, matrices);
- return ret;
- }
-
- /**
- * Returns a column vector where each entry is the nth bilinear
- * product of the nth slices of the two tensors.
- */
- public static INDArray bilinearProducts(INDArray curr, INDArray in) {
- return INSTANCE.bilinearProducts(curr, in);
+ return INSTANCE.toFlattened(order, matrices);
}
/**
@@ -2243,32 +2129,28 @@ public class Nd4j {
* Create the identity ndarray
*
* @param n the number for the identity
- * @return
+ * @return the identity array
*/
public static INDArray eye(long n) {
- INDArray ret = INSTANCE.eye(n);
- return ret;
+ return INSTANCE.eye(n);
}
/**
* Rotate a matrix 90 degrees
*
* @param toRotate the matrix to rotate
- * @return the rotated matrix
*/
public static void rot90(INDArray toRotate) {
INSTANCE.rot90(toRotate);
-
}
/**
* Write NDArray to a text file
*
- * @param filePath
+ * @param filePath path to write to
* @param split the split separator, defaults to ","
- * @deprecated custom col separators are no longer supported; uses ","
* @param precision digits after the decimal point
- * @deprecated Precision is no longer used.
+ * @deprecated Precision is no longer used. Split is no longer used.
* Defaults to scientific notation with 18 digits after the decimal
* Use {@link #writeTxt(INDArray, String)}
*/
@@ -2279,10 +2161,10 @@ public class Nd4j {
/**
* Write NDArray to a text file
*
- * @param write
- * @param filePath
- * @param precision
- * @deprecated Precision is no longer used.
+ * @param write array to write
+ * @param filePath path to write to
+ * @param precision Precision is no longer used.
+ * @deprecated
* Defaults to scientific notation with 18 digits after the decimal
* Use {@link #writeTxt(INDArray, String)}
*/
@@ -2293,9 +2175,9 @@ public class Nd4j {
/**
* Write NDArray to a text file
*
- * @param write
- * @param filePath
- * @param split
+ * @param write array to write
+ * @param filePath path to write to
+ * @param split the split separator, defaults to ","
* @deprecated custom col and higher dimension separators are no longer supported; uses ","
* Use {@link #writeTxt(INDArray, String)}
*/
@@ -2311,69 +2193,27 @@ public class Nd4j {
*/
public static void writeTxt(INDArray write, String filePath) {
try {
- String toWrite = writeStringForArray(write, "0.000000000000000000E0");
- FileUtils.writeStringToFile(new File(filePath), toWrite);
+ String toWrite = writeStringForArray(write);
+ FileUtils.writeStringToFile(new File(filePath), toWrite, (String)null, false);
} catch (IOException e) {
throw new RuntimeException("Error writing output", e);
}
}
- /**
- * @deprecated custom col separators are no longer supported; uses ","
- * @deprecated precision can no longer be specified. The array is written in scientific notation.
- * @see #writeTxtString(INDArray, OutputStream)
- */
- public static void writeTxtString(INDArray write, OutputStream os, String split, int precision) {
- writeTxtString(write,os);
- }
-
- /**
- * @deprecated precision can no longer be specified. The array is written in scientific notation.
- * @see #writeTxtString(INDArray, OutputStream)
- */
- @Deprecated
- public static void writeTxtString(INDArray write, OutputStream os, int precision) {
- writeTxtString(write,os);
- }
-
- /**
- * @deprecated column separator can longer be specified; Uses ","
- * @see #writeTxtString(INDArray, OutputStream)
- */
- @Deprecated
- public static void writeTxtString(INDArray write, OutputStream os, String split) {
- writeTxtString(write, os);
- }
-
- /**
- * Write ndarray as text to output stream
- * @param write Array to write
- * @param os stream to write too.
- */
- public static void writeTxtString(INDArray write, OutputStream os) {
- try {
- // default format is "0.000000000000000000E0"
- String toWrite = writeStringForArray(write, "0.000000000000000000E0");
- os.write(toWrite.getBytes());
- } catch (IOException e) {
- throw new RuntimeException("Error writing output", e);
- }
- }
-
- private static String writeStringForArray(INDArray write, String format) {
+ private static String writeStringForArray(INDArray write) {
if(write.isView() || !Shape.hasDefaultStridesForShape(write))
write = write.dup();
- if (format.isEmpty()) format = "0.000000000000000000E0";
- String lineOne = "{\n";
- String lineTwo = "\"filefrom\": \"dl4j\",\n";
- String lineThree = "\"ordering\": \"" + write.ordering() + "\",\n";
- String lineFour = "\"shape\":\t" + java.util.Arrays.toString(write.shape()) + ",\n";
- String lineFive = "\"data\":\n";
- String fileData = new NDArrayStrings(",", format).format(write, false);
- String fileEnd = "\n}\n";
- String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive;
- String fileContents = fileBegin + fileData + fileEnd;
- return fileContents;
+ String format = "0.000000000000000000E0";
+
+ return new StringBuilder()
+ .append("{\n")
+ .append("\"filefrom\": \"dl4j\",\n")
+ .append( "\"ordering\": \"").append(write.ordering()).append("\",\n")
+ .append("\"shape\":\t").append( java.util.Arrays.toString(write.shape())).append(",\n")
+ .append("\"data\":\n")
+ .append(new NDArrayStrings(",", format).format(write, false))
+ .append("\n}\n")
+ .toString();
}