nd4j refactoring. (#103)

Restored writeTxt per slack DM.
master
Robert Altena 2019-08-07 19:31:48 +09:00 committed by GitHub
parent edb71bf46f
commit f8615e0ef0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 443 deletions

View File

@ -1,144 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ndarray;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
import java.io.IOException;
/**
* Created by susaneraly on 6/16/16.
*/
@Deprecated
public class NdArrayJSONReader {
public INDArray read(File jsonFile) {
INDArray result = this.loadNative(jsonFile);
if (result == null) {
//Must write support for parsing/normal json parsing - which will be inefficient
this.loadNonNative(jsonFile);
}
return result;
}
private INDArray loadNative(File jsonFile) {
/*
We could dump an ndarray to a file with the tostring (since that is valid json) and use put/get to parse it as json
But here we leverage our information of the tostring method to be more efficient
With our current toString format we use tads along dimension (rank-1,rank-2) to write to the array in two dimensional chunks at a time.
This is more efficient than setting each value at a time with putScalar.
This also means we can read the file one line at a time instead of loading the whole thing into memory
Future work involves enhancing the write json method to provide more features to make the load more efficient
*/
int lineNum = 0;
int rowNum = 0;
int tensorNum = 0;
char theOrder = 'c';
int[] theShape = {1, 1};
int rank = 0;
double[][] subsetArr = {{0.0, 0.0}, {0.0, 0.0}};
INDArray newArr = Nd4j.zeros(2, 2);
try {
LineIterator it = FileUtils.lineIterator(jsonFile);
try {
while (it.hasNext()) {
String line = it.nextLine();
lineNum++;
line = line.replaceAll("\\s", "");
if (line.equals("") || line.equals("}"))
continue;
// is it from dl4j?
if (lineNum == 2) {
String[] lineArr = line.split(":");
String fileSource = lineArr[1].replaceAll("\\W", "");
if (!fileSource.equals("dl4j"))
return null;
}
// parse ordering
if (lineNum == 3) {
String[] lineArr = line.split(":");
theOrder = lineArr[1].replace("\\W", "").charAt(0);
continue;
}
// parse shape
if (lineNum == 4) {
String[] lineArr = line.split(":");
String dropJsonComma = lineArr[1].split("]")[0];
String[] shapeString = dropJsonComma.replace("[", "").split(",");
rank = shapeString.length;
theShape = new int[rank];
for (int i = 0; i < rank; i++) {
try {
theShape[i] = Integer.parseInt(shapeString[i]);
} catch (NumberFormatException nfe) {
} ;
}
subsetArr = new double[theShape[rank - 2]][theShape[rank - 1]];
newArr = Nd4j.zeros(theShape, theOrder);
continue;
}
//parse data
if (lineNum > 5) {
String[] entries =
line.replace("\\],", "").replaceAll("\\[", "").replaceAll("\\]", "").split(",");
for (int i = 0; i < theShape[rank - 1]; i++) {
try {
subsetArr[rowNum][i] = Double.parseDouble(entries[i]);
} catch (NumberFormatException nfe) {
}
}
rowNum++;
if (rowNum == theShape[rank - 2]) {
INDArray subTensor = Nd4j.create(subsetArr);
newArr.tensorAlongDimension(tensorNum, rank - 1, rank - 2).addi(subTensor);
rowNum = 0;
tensorNum++;
}
}
}
} finally {
LineIterator.closeQuietly(it);
}
} catch (IOException e) {
throw new RuntimeException("Error reading input", e);
}
return newArr;
}
private INDArray loadNonNative(File jsonFile) {
/* WIP
JSONTokener tokener = new JSONTokener(new FileReader("test.json"));
JSONObject obj = new JSONObject(tokener);
JSONArray objArr = obj.optJSONArray("shape");
int rank = objArr.length();
int[] theShape = new int[rank];
int rows = 1;
for (int i = 0; i < rank; ++i) {
theShape[i] = objArr.optInt(i);
if (i != objArr.length() - 1)
rows *= theShape[i];
}
*/
System.out.println("API_Error: Current support only for files written from dl4j");
return null;
}
}

View File

@ -1,54 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ndarray;
import org.apache.commons.io.FileUtils;
import java.io.File;
import java.io.IOException;
/**
* Created by susaneraly on 6/18/16.
*/
@Deprecated
public class NdArrayJSONWriter {
private NdArrayJSONWriter() {}
/**
*
* @param thisnD
* @param filePath
*/
public static void write(INDArray thisnD, String filePath) {
//TO DO: Add precision support in toString
//TO DO: Write to file one line at time
String lineOne = "{\n";
String lineTwo = "\"filefrom\": \"dl4j\",\n";
String lineThree = "\"ordering\": \"" + thisnD.ordering() + "\",\n";
String lineFour = "\"shape\":\t" + java.util.Arrays.toString(thisnD.shape()) + ",\n";
String lineFive = "\"data\":\n";
String fileData = thisnD.toString();
String fileEnd = "\n}\n";
String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive;
try {
FileUtils.writeStringToFile(new File(filePath), fileBegin + fileData + fileEnd);
} catch (IOException e) {
throw new RuntimeException("Error writing output", e);
}
}
}

View File

@ -75,7 +75,6 @@ import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.convolution.ConvolutionInstance;
import org.nd4j.linalg.convolution.DefaultConvolutionInstance;
import org.nd4j.linalg.env.EnvironmentalAction;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JUnknownDataTypeException;
import org.nd4j.linalg.factory.Nd4jBackend.NoAvailableBackendException;
@ -1073,7 +1072,7 @@ public class Nd4j {
return ret;
}
protected static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
private static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
switch (dataType) {
case UINT64:
case LONG:
@ -1113,40 +1112,7 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(@NonNull Pointer pointer, long length, @NonNull DataType dataType) {
Pointer nPointer = null;
//TODO: remove dupplicate code.
switch (dataType) {
case LONG:
nPointer = new LongPointer(pointer);
break;
case INT:
nPointer = new IntPointer(pointer);
break;
case SHORT:
nPointer = new ShortPointer(pointer);
break;
case BYTE:
nPointer = new BytePointer(pointer);
break;
case UBYTE:
nPointer = new BytePointer(pointer);
break;
case BOOL:
nPointer = new BooleanPointer(pointer);
break;
case FLOAT:
nPointer = new FloatPointer(pointer);
break;
case HALF:
nPointer = new ShortPointer(pointer);
break;
case DOUBLE:
nPointer = new DoublePointer(pointer);
break;
default:
throw new UnsupportedOperationException("Unsupported data type: " + dataType);
}
Pointer nPointer = getPointer(pointer, dataType);
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, dataType, length, getIndexerByType(nPointer, dataType));
}
@ -1161,7 +1127,12 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) {
Pointer nPointer = null;
Pointer nPointer = getPointer(pointer, dataType);
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType));
}
private static Pointer getPointer(@NonNull Pointer pointer, @NonNull DataType dataType ){
Pointer nPointer;
switch (dataType) {
case UINT64:
case LONG:
@ -1198,7 +1169,7 @@ public class Nd4j {
throw new UnsupportedOperationException("Unsupported data type: " + dataType);
}
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType));
return nPointer;
}
/**
@ -1208,9 +1179,8 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(float[] data, long offset) {
val ndata = Arrays.copyOfRange(data, (int) offset, data.length);
DataBuffer ret = createTypedBuffer(ndata, DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
return ret;
return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length),
DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
@ -1220,9 +1190,8 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(double[] data, long offset) {
val ndata = Arrays.copyOfRange(data, (int) offset, data.length);
DataBuffer ret = createTypedBuffer(ndata, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
return ret;
return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length),
DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
@ -1294,7 +1263,7 @@ public class Nd4j {
*
* @param shape the shape of the buffer to create
* @param type the opType to create
* @return
* @return the created buffer.
*/
public static DataBuffer createBufferDetached(int[] shape, DataType type) {
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type);
@ -1354,7 +1323,7 @@ public class Nd4j {
* @param buffer the buffer to create from
* @param type the opType of buffer to create
* @param length the length of the buffer
* @return
* @return the created buffer
*/
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) {
switch (type) {
@ -1407,33 +1376,27 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(long[] data) {
DataBuffer ret;
ret = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace());
return ret;
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
* Create a buffer equal of length prod(shape). This method is NOT affected by workspaces
*
* @param data
* @return
* @param data the shape of the buffer to create
* @return the created buffer
*/
public static DataBuffer createBufferDetached(int[] data) {
DataBuffer ret;
ret = DATA_BUFFER_FACTORY_INSTANCE.createInt(data);
return ret;
return DATA_BUFFER_FACTORY_INSTANCE.createInt(data);
}
/**
* Create a buffer equal of length prod(shape). This method is NOT affected by workspaces
*
* @param data
* @return
* @param data the shape of the buffer to create
* @return the created buffer
*/
public static DataBuffer createBufferDetached(long[] data) {
DataBuffer ret;
ret = DATA_BUFFER_FACTORY_INSTANCE.createLong(data);
return ret;
return DATA_BUFFER_FACTORY_INSTANCE.createLong(data);
}
/**
@ -1464,9 +1427,7 @@ public class Nd4j {
* See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype.
*/
public static DataBuffer createBuffer(long length, boolean initialize) {
DataBuffer ret = createBuffer(Nd4j.dataType(), length, initialize);
return ret;
return createBuffer(Nd4j.dataType(), length, initialize);
}
/**
@ -1524,6 +1485,11 @@ public class Nd4j {
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(data) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, Nd4j.getMemoryManager().getCurrentWorkspace());
}
// refactoring of duplicate code.
private static DataBuffer getDataBuffer(int length, DataType dataType){
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
* Create a buffer based on the data of the underlying java array with the specified type..
* @param data underlying java array
@ -1531,16 +1497,16 @@ public class Nd4j {
* @return created buffer,
*/
public static DataBuffer createTypedBuffer(double[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType)}
* See {@link #createTypedBuffer(double[], DataType)}
*/
public static DataBuffer createTypedBuffer(float[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@ -1549,7 +1515,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(int[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@ -1558,7 +1524,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(long[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@ -1567,7 +1533,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(short[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@ -1576,7 +1542,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(byte[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
@ -1585,11 +1551,17 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)}
*/
public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data);
return buffer;
}
// refactoring of duplicate code.
private static DataBuffer getDataBuffer(int length, DataType dataType, MemoryWorkspace workspace){
return workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, workspace);
}
/**
* Create a buffer based on the data of the underlying java array, specified type and workspace
* @param data underlying java array
@ -1598,61 +1570,18 @@ public class Nd4j {
* @return created buffer,
*/
public static DataBuffer createTypedBuffer(double[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
//val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
DataBuffer buffer = getDataBuffer(data.length, dataType, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
* See {@link #createTypedBuffer(double[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(float[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(int[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(long[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(short[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(byte[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
//val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
DataBuffer buffer = getDataBuffer(data.length, dataType, workspace);
buffer.setData(data);
return buffer;
}
@ -1661,7 +1590,7 @@ public class Nd4j {
* Create am uninitialized buffer based on the data of the underlying java array and specified type.
* @param data underlying java array
* @param dataType specified type.
* @return
* @return the created buffer.
*/
public static DataBuffer createTypedBufferDetached(double[] data, DataType dataType) {
val buffer = DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false);
@ -1731,14 +1660,6 @@ public class Nd4j {
INSTANCE = factory;
}
/**
* Set the factory instance for sparse INDArray creation.
* @param factory new INDArray factory
*/
public static void setSparseFactory(NDArrayFactory factory) {
SPARSE_INSTANCE = factory;
}
/**
* Returns the ordering of the ndarrays
*
@ -1810,15 +1731,6 @@ public class Nd4j {
return SPARSE_BLAS_WRAPPER_INSTANCE;
}
/**
* Sets the global blas wrapper
*
* @param factory
*/
public static void setBlasWrapper(BlasWrapper factory) {
BLAS_WRAPPER_INSTANCE = factory;
}
/**
* Sort an ndarray along a particular dimension.<br>
* Note that the input array is modified in-place.
@ -1899,7 +1811,7 @@ public class Nd4j {
*
* @param ndarray array to sort
* @param ascending true for ascending, false for descending
* @return
* @return the sorted ndarray
*/
public static INDArray sort(INDArray ndarray, boolean ascending) {
return getNDArrayFactory().sort(ndarray, !ascending);
@ -1931,20 +1843,18 @@ public class Nd4j {
* @param in 2d array to sort
* @param colIdx The column to sort on
* @param ascending true if smallest-to-largest; false if largest-to-smallest
* @return
* @return the sorted ndarray
*/
@SuppressWarnings("Duplicates")
public static INDArray sortRows(final INDArray in, final int colIdx, final boolean ascending) {
if (in.rank() != 2)
throw new IllegalArgumentException("Cannot sort rows on non-2d matrix");
if (colIdx < 0 || colIdx >= in.columns())
throw new IllegalArgumentException("Cannot sort on values in column " + colIdx + ", nCols=" + in.columns());
if (in.rows() > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
INDArray out = Nd4j.create(in.dataType(), in.shape());
int nRows = (int) in.rows();
ArrayList<Integer> list = new ArrayList<Integer>(nRows);
int nRows = in.rows();
ArrayList<Integer> list = new ArrayList<>(nRows);
for (int i = 0; i < nRows; i++)
list.add(i);
Collections.sort(list, new Comparator<Integer>() {
@ -1976,19 +1886,17 @@ public class Nd4j {
* @param in 2d array to sort
* @param rowIdx The row to sort on
* @param ascending true if smallest-to-largest; false if largest-to-smallest
* @return
* @return the sorted array.
*/
@SuppressWarnings("Duplicates")
public static INDArray sortColumns(final INDArray in, final int rowIdx, final boolean ascending) {
if (in.rank() != 2)
throw new IllegalArgumentException("Cannot sort columns on non-2d matrix");
if (rowIdx < 0 || rowIdx >= in.rows())
throw new IllegalArgumentException("Cannot sort on values in row " + rowIdx + ", nRows=" + in.rows());
if (in.columns() > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
INDArray out = Nd4j.create(in.shape());
int nCols = (int) in.columns();
int nCols = in.columns();
ArrayList<Integer> list = new ArrayList<>(nCols);
for (int i = 0; i < nCols; i++)
list.add(i);
@ -2041,7 +1949,7 @@ public class Nd4j {
if (dtype.isIntType()) {
long upper = lower + num * step;
return linspaceWithCustomOpByRange((long) lower, upper, num, (long) step, dtype);
return linspaceWithCustomOpByRange( lower, upper, num, step, dtype);
} else if (dtype.isFPType()) {
return Nd4j.getExecutioner().exec(new Linspace((double) lower, num, (double)step, dtype));
}
@ -2062,7 +1970,6 @@ public class Nd4j {
return linspace(lower, upper, num, Nd4j.dataType());
}
/**
* Generate a linearly spaced vector
*
@ -2077,7 +1984,7 @@ public class Nd4j {
return Nd4j.scalar(dtype, lower);
}
if (dtype.isIntType()) {
return linspaceWithCustomOp((long)lower, (long)upper, (int)num, dtype);
return linspaceWithCustomOp(lower, upper, (int)num, dtype);
} else if (dtype.isFPType()) {
return linspace((double) lower, (double)upper, (int) num, dtype);
}
@ -2182,8 +2089,7 @@ public class Nd4j {
* these ndarrays
*/
public static INDArray toFlattened(Collection<INDArray> matrices) {
INDArray ret = INSTANCE.toFlattened(matrices);
return ret;
return INSTANCE.toFlattened(matrices);
}
/**
@ -2194,27 +2100,7 @@ public class Nd4j {
* these ndarrays
*/
public static INDArray toFlattened(char order, Collection<INDArray> matrices) {
INDArray ret = INSTANCE.toFlattened(order, matrices);
return ret;
}
/**
* Create a long row vector of all of the given ndarrays
* @param matrices the matrices to create the flattened ndarray for
* @return the flattened representation of
* these ndarrays
*/
public static INDArray toFlattened(int length, Iterator<? extends INDArray>... matrices) {
INDArray ret = INSTANCE.toFlattened(length, matrices);
return ret;
}
/**
* Returns a column vector where each entry is the nth bilinear
* product of the nth slices of the two tensors.
*/
public static INDArray bilinearProducts(INDArray curr, INDArray in) {
return INSTANCE.bilinearProducts(curr, in);
return INSTANCE.toFlattened(order, matrices);
}
/**
@ -2243,32 +2129,28 @@ public class Nd4j {
* Create the identity ndarray
*
* @param n the number for the identity
* @return
* @return the identity array
*/
public static INDArray eye(long n) {
INDArray ret = INSTANCE.eye(n);
return ret;
return INSTANCE.eye(n);
}
/**
* Rotate a matrix 90 degrees
*
* @param toRotate the matrix to rotate
* @return the rotated matrix
*/
public static void rot90(INDArray toRotate) {
INSTANCE.rot90(toRotate);
}
/**
* Write NDArray to a text file
*
* @param filePath
* @param filePath path to write to
* @param split the split separator, defaults to ","
* @deprecated custom col separators are no longer supported; uses ","
* @param precision digits after the decimal point
* @deprecated Precision is no longer used.
* @deprecated Precision is no longer used. Split is no longer used.
* Defaults to scientific notation with 18 digits after the decimal
* Use {@link #writeTxt(INDArray, String)}
*/
@ -2279,10 +2161,10 @@ public class Nd4j {
/**
* Write NDArray to a text file
*
* @param write
* @param filePath
* @param precision
* @deprecated Precision is no longer used.
* @param write array to write
* @param filePath path to write to
* @param precision Precision is no longer used.
* @deprecated
* Defaults to scientific notation with 18 digits after the decimal
* Use {@link #writeTxt(INDArray, String)}
*/
@ -2293,9 +2175,9 @@ public class Nd4j {
/**
* Write NDArray to a text file
*
* @param write
* @param filePath
* @param split
* @param write array to write
* @param filePath path to write to
* @param split the split separator, defaults to ","
* @deprecated custom col and higher dimension separators are no longer supported; uses ","
* Use {@link #writeTxt(INDArray, String)}
*/
@ -2311,69 +2193,27 @@ public class Nd4j {
*/
public static void writeTxt(INDArray write, String filePath) {
try {
String toWrite = writeStringForArray(write, "0.000000000000000000E0");
FileUtils.writeStringToFile(new File(filePath), toWrite);
String toWrite = writeStringForArray(write);
FileUtils.writeStringToFile(new File(filePath), toWrite, (String)null, false);
} catch (IOException e) {
throw new RuntimeException("Error writing output", e);
}
}
/**
* @deprecated custom col separators are no longer supported; uses ","
* @deprecated precision can no longer be specified. The array is written in scientific notation.
* @see #writeTxtString(INDArray, OutputStream)
*/
public static void writeTxtString(INDArray write, OutputStream os, String split, int precision) {
writeTxtString(write,os);
}
/**
* @deprecated precision can no longer be specified. The array is written in scientific notation.
* @see #writeTxtString(INDArray, OutputStream)
*/
@Deprecated
public static void writeTxtString(INDArray write, OutputStream os, int precision) {
writeTxtString(write,os);
}
/**
* @deprecated column separator can longer be specified; Uses ","
* @see #writeTxtString(INDArray, OutputStream)
*/
@Deprecated
public static void writeTxtString(INDArray write, OutputStream os, String split) {
writeTxtString(write, os);
}
/**
* Write ndarray as text to output stream
* @param write Array to write
* @param os stream to write too.
*/
public static void writeTxtString(INDArray write, OutputStream os) {
try {
// default format is "0.000000000000000000E0"
String toWrite = writeStringForArray(write, "0.000000000000000000E0");
os.write(toWrite.getBytes());
} catch (IOException e) {
throw new RuntimeException("Error writing output", e);
}
}
private static String writeStringForArray(INDArray write, String format) {
private static String writeStringForArray(INDArray write) {
if(write.isView() || !Shape.hasDefaultStridesForShape(write))
write = write.dup();
if (format.isEmpty()) format = "0.000000000000000000E0";
String lineOne = "{\n";
String lineTwo = "\"filefrom\": \"dl4j\",\n";
String lineThree = "\"ordering\": \"" + write.ordering() + "\",\n";
String lineFour = "\"shape\":\t" + java.util.Arrays.toString(write.shape()) + ",\n";
String lineFive = "\"data\":\n";
String fileData = new NDArrayStrings(",", format).format(write, false);
String fileEnd = "\n}\n";
String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive;
String fileContents = fileBegin + fileData + fileEnd;
return fileContents;
String format = "0.000000000000000000E0";
return new StringBuilder()
.append("{\n")
.append("\"filefrom\": \"dl4j\",\n")
.append( "\"ordering\": \"").append(write.ordering()).append("\",\n")
.append("\"shape\":\t").append( java.util.Arrays.toString(write.shape())).append(",\n")
.append("\"data\":\n")
.append(new NDArrayStrings(",", format).format(write, false))
.append("\n}\n")
.toString();
}