nd4j refactoring. (#103)

Restored writeTxt per slack DM.
master
Robert Altena 2019-08-07 19:31:48 +09:00 committed by GitHub
parent edb71bf46f
commit f8615e0ef0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 443 deletions

View File

@ -1,144 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ndarray;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
import java.io.IOException;
/**
* Created by susaneraly on 6/16/16.
*/
@Deprecated
public class NdArrayJSONReader {
public INDArray read(File jsonFile) {
INDArray result = this.loadNative(jsonFile);
if (result == null) {
//Must write support for parsing/normal json parsing - which will be inefficient
this.loadNonNative(jsonFile);
}
return result;
}
private INDArray loadNative(File jsonFile) {
/*
We could dump an ndarray to a file with the tostring (since that is valid json) and use put/get to parse it as json
But here we leverage our information of the tostring method to be more efficient
With our current toString format we use tads along dimension (rank-1,rank-2) to write to the array in two dimensional chunks at a time.
This is more efficient than setting each value at a time with putScalar.
This also means we can read the file one line at a time instead of loading the whole thing into memory
Future work involves enhancing the write json method to provide more features to make the load more efficient
*/
int lineNum = 0;
int rowNum = 0;
int tensorNum = 0;
char theOrder = 'c';
int[] theShape = {1, 1};
int rank = 0;
double[][] subsetArr = {{0.0, 0.0}, {0.0, 0.0}};
INDArray newArr = Nd4j.zeros(2, 2);
try {
LineIterator it = FileUtils.lineIterator(jsonFile);
try {
while (it.hasNext()) {
String line = it.nextLine();
lineNum++;
line = line.replaceAll("\\s", "");
if (line.equals("") || line.equals("}"))
continue;
// is it from dl4j?
if (lineNum == 2) {
String[] lineArr = line.split(":");
String fileSource = lineArr[1].replaceAll("\\W", "");
if (!fileSource.equals("dl4j"))
return null;
}
// parse ordering
if (lineNum == 3) {
String[] lineArr = line.split(":");
theOrder = lineArr[1].replace("\\W", "").charAt(0);
continue;
}
// parse shape
if (lineNum == 4) {
String[] lineArr = line.split(":");
String dropJsonComma = lineArr[1].split("]")[0];
String[] shapeString = dropJsonComma.replace("[", "").split(",");
rank = shapeString.length;
theShape = new int[rank];
for (int i = 0; i < rank; i++) {
try {
theShape[i] = Integer.parseInt(shapeString[i]);
} catch (NumberFormatException nfe) {
} ;
}
subsetArr = new double[theShape[rank - 2]][theShape[rank - 1]];
newArr = Nd4j.zeros(theShape, theOrder);
continue;
}
//parse data
if (lineNum > 5) {
String[] entries =
line.replace("\\],", "").replaceAll("\\[", "").replaceAll("\\]", "").split(",");
for (int i = 0; i < theShape[rank - 1]; i++) {
try {
subsetArr[rowNum][i] = Double.parseDouble(entries[i]);
} catch (NumberFormatException nfe) {
}
}
rowNum++;
if (rowNum == theShape[rank - 2]) {
INDArray subTensor = Nd4j.create(subsetArr);
newArr.tensorAlongDimension(tensorNum, rank - 1, rank - 2).addi(subTensor);
rowNum = 0;
tensorNum++;
}
}
}
} finally {
LineIterator.closeQuietly(it);
}
} catch (IOException e) {
throw new RuntimeException("Error reading input", e);
}
return newArr;
}
private INDArray loadNonNative(File jsonFile) {
/* WIP
JSONTokener tokener = new JSONTokener(new FileReader("test.json"));
JSONObject obj = new JSONObject(tokener);
JSONArray objArr = obj.optJSONArray("shape");
int rank = objArr.length();
int[] theShape = new int[rank];
int rows = 1;
for (int i = 0; i < rank; ++i) {
theShape[i] = objArr.optInt(i);
if (i != objArr.length() - 1)
rows *= theShape[i];
}
*/
System.out.println("API_Error: Current support only for files written from dl4j");
return null;
}
}

View File

@ -1,54 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ndarray;
import org.apache.commons.io.FileUtils;
import java.io.File;
import java.io.IOException;
/**
* Created by susaneraly on 6/18/16.
*/
@Deprecated
public class NdArrayJSONWriter {
private NdArrayJSONWriter() {}
/**
*
* @param thisnD
* @param filePath
*/
public static void write(INDArray thisnD, String filePath) {
//TO DO: Add precision support in toString
//TO DO: Write to file one line at time
String lineOne = "{\n";
String lineTwo = "\"filefrom\": \"dl4j\",\n";
String lineThree = "\"ordering\": \"" + thisnD.ordering() + "\",\n";
String lineFour = "\"shape\":\t" + java.util.Arrays.toString(thisnD.shape()) + ",\n";
String lineFive = "\"data\":\n";
String fileData = thisnD.toString();
String fileEnd = "\n}\n";
String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive;
try {
FileUtils.writeStringToFile(new File(filePath), fileBegin + fileData + fileEnd);
} catch (IOException e) {
throw new RuntimeException("Error writing output", e);
}
}
}

View File

@ -75,7 +75,6 @@ import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.convolution.ConvolutionInstance; import org.nd4j.linalg.convolution.ConvolutionInstance;
import org.nd4j.linalg.convolution.DefaultConvolutionInstance; import org.nd4j.linalg.convolution.DefaultConvolutionInstance;
import org.nd4j.linalg.env.EnvironmentalAction; import org.nd4j.linalg.env.EnvironmentalAction;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JUnknownDataTypeException; import org.nd4j.linalg.exception.ND4JUnknownDataTypeException;
import org.nd4j.linalg.factory.Nd4jBackend.NoAvailableBackendException; import org.nd4j.linalg.factory.Nd4jBackend.NoAvailableBackendException;
@ -1073,7 +1072,7 @@ public class Nd4j {
return ret; return ret;
} }
protected static Indexer getIndexerByType(Pointer pointer, DataType dataType) { private static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
switch (dataType) { switch (dataType) {
case UINT64: case UINT64:
case LONG: case LONG:
@ -1113,40 +1112,7 @@ public class Nd4j {
* @return the created buffer * @return the created buffer
*/ */
public static DataBuffer createBuffer(@NonNull Pointer pointer, long length, @NonNull DataType dataType) { public static DataBuffer createBuffer(@NonNull Pointer pointer, long length, @NonNull DataType dataType) {
Pointer nPointer = null; Pointer nPointer = getPointer(pointer, dataType);
//TODO: remove dupplicate code.
switch (dataType) {
case LONG:
nPointer = new LongPointer(pointer);
break;
case INT:
nPointer = new IntPointer(pointer);
break;
case SHORT:
nPointer = new ShortPointer(pointer);
break;
case BYTE:
nPointer = new BytePointer(pointer);
break;
case UBYTE:
nPointer = new BytePointer(pointer);
break;
case BOOL:
nPointer = new BooleanPointer(pointer);
break;
case FLOAT:
nPointer = new FloatPointer(pointer);
break;
case HALF:
nPointer = new ShortPointer(pointer);
break;
case DOUBLE:
nPointer = new DoublePointer(pointer);
break;
default:
throw new UnsupportedOperationException("Unsupported data type: " + dataType);
}
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, dataType, length, getIndexerByType(nPointer, dataType)); return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, dataType, length, getIndexerByType(nPointer, dataType));
} }
@ -1161,7 +1127,12 @@ public class Nd4j {
* @return the created buffer * @return the created buffer
*/ */
public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) { public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) {
Pointer nPointer = null; Pointer nPointer = getPointer(pointer, dataType);
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType));
}
private static Pointer getPointer(@NonNull Pointer pointer, @NonNull DataType dataType ){
Pointer nPointer;
switch (dataType) { switch (dataType) {
case UINT64: case UINT64:
case LONG: case LONG:
@ -1198,7 +1169,7 @@ public class Nd4j {
throw new UnsupportedOperationException("Unsupported data type: " + dataType); throw new UnsupportedOperationException("Unsupported data type: " + dataType);
} }
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType)); return nPointer;
} }
/** /**
@ -1208,9 +1179,8 @@ public class Nd4j {
* @return the created buffer * @return the created buffer
*/ */
public static DataBuffer createBuffer(float[] data, long offset) { public static DataBuffer createBuffer(float[] data, long offset) {
val ndata = Arrays.copyOfRange(data, (int) offset, data.length); return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length),
DataBuffer ret = createTypedBuffer(ndata, DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
return ret;
} }
/** /**
@ -1220,9 +1190,8 @@ public class Nd4j {
* @return the created buffer * @return the created buffer
*/ */
public static DataBuffer createBuffer(double[] data, long offset) { public static DataBuffer createBuffer(double[] data, long offset) {
val ndata = Arrays.copyOfRange(data, (int) offset, data.length); return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length),
DataBuffer ret = createTypedBuffer(ndata, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
return ret;
} }
/** /**
@ -1294,7 +1263,7 @@ public class Nd4j {
* *
* @param shape the shape of the buffer to create * @param shape the shape of the buffer to create
* @param type the opType to create * @param type the opType to create
* @return * @return the created buffer.
*/ */
public static DataBuffer createBufferDetached(int[] shape, DataType type) { public static DataBuffer createBufferDetached(int[] shape, DataType type) {
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type); return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type);
@ -1354,7 +1323,7 @@ public class Nd4j {
* @param buffer the buffer to create from * @param buffer the buffer to create from
* @param type the opType of buffer to create * @param type the opType of buffer to create
* @param length the length of the buffer * @param length the length of the buffer
* @return * @return the created buffer
*/ */
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) { public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) {
switch (type) { switch (type) {
@ -1407,33 +1376,27 @@ public class Nd4j {
* @return the created buffer * @return the created buffer
*/ */
public static DataBuffer createBuffer(long[] data) { public static DataBuffer createBuffer(long[] data) {
DataBuffer ret; return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace());
ret = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace());
return ret;
} }
/** /**
* Create a buffer equal of length prod(shape). This method is NOT affected by workspaces * Create a buffer equal of length prod(shape). This method is NOT affected by workspaces
* *
* @param data * @param data the shape of the buffer to create
* @return * @return the created buffer
*/ */
public static DataBuffer createBufferDetached(int[] data) { public static DataBuffer createBufferDetached(int[] data) {
DataBuffer ret; return DATA_BUFFER_FACTORY_INSTANCE.createInt(data);
ret = DATA_BUFFER_FACTORY_INSTANCE.createInt(data);
return ret;
} }
/** /**
* Create a buffer equal of length prod(shape). This method is NOT affected by workspaces * Create a buffer equal of length prod(shape). This method is NOT affected by workspaces
* *
* @param data * @param data the shape of the buffer to create
* @return * @return the created buffer
*/ */
public static DataBuffer createBufferDetached(long[] data) { public static DataBuffer createBufferDetached(long[] data) {
DataBuffer ret; return DATA_BUFFER_FACTORY_INSTANCE.createLong(data);
ret = DATA_BUFFER_FACTORY_INSTANCE.createLong(data);
return ret;
} }
/** /**
@ -1464,9 +1427,7 @@ public class Nd4j {
* See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype. * See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype.
*/ */
public static DataBuffer createBuffer(long length, boolean initialize) { public static DataBuffer createBuffer(long length, boolean initialize) {
DataBuffer ret = createBuffer(Nd4j.dataType(), length, initialize); return createBuffer(Nd4j.dataType(), length, initialize);
return ret;
} }
/** /**
@ -1524,6 +1485,11 @@ public class Nd4j {
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(data) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, Nd4j.getMemoryManager().getCurrentWorkspace()); return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(data) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, Nd4j.getMemoryManager().getCurrentWorkspace());
} }
// refactoring of duplicate code.
private static DataBuffer getDataBuffer(int length, DataType dataType){
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
}
/** /**
* Create a buffer based on the data of the underlying java array with the specified type.. * Create a buffer based on the data of the underlying java array with the specified type..
* @param data underlying java array * @param data underlying java array
@ -1531,16 +1497,16 @@ public class Nd4j {
* @return created buffer, * @return created buffer,
*/ */
public static DataBuffer createTypedBuffer(double[] data, DataType dataType) { public static DataBuffer createTypedBuffer(double[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data); buffer.setData(data);
return buffer; return buffer;
} }
/** /**
* See {@link #createTypedBuffer(float[], DataType)} * See {@link #createTypedBuffer(double[], DataType)}
*/ */
public static DataBuffer createTypedBuffer(float[] data, DataType dataType) { public static DataBuffer createTypedBuffer(float[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data); buffer.setData(data);
return buffer; return buffer;
} }
@ -1549,7 +1515,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)} * See {@link #createTypedBuffer(float[], DataType)}
*/ */
public static DataBuffer createTypedBuffer(int[] data, DataType dataType) { public static DataBuffer createTypedBuffer(int[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data); buffer.setData(data);
return buffer; return buffer;
} }
@ -1558,7 +1524,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)} * See {@link #createTypedBuffer(float[], DataType)}
*/ */
public static DataBuffer createTypedBuffer(long[] data, DataType dataType) { public static DataBuffer createTypedBuffer(long[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data); buffer.setData(data);
return buffer; return buffer;
} }
@ -1567,7 +1533,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)} * See {@link #createTypedBuffer(float[], DataType)}
*/ */
public static DataBuffer createTypedBuffer(short[] data, DataType dataType) { public static DataBuffer createTypedBuffer(short[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data); buffer.setData(data);
return buffer; return buffer;
} }
@ -1576,7 +1542,7 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)} * See {@link #createTypedBuffer(float[], DataType)}
*/ */
public static DataBuffer createTypedBuffer(byte[] data, DataType dataType) { public static DataBuffer createTypedBuffer(byte[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data); buffer.setData(data);
return buffer; return buffer;
} }
@ -1585,11 +1551,17 @@ public class Nd4j {
* See {@link #createTypedBuffer(float[], DataType)} * See {@link #createTypedBuffer(float[], DataType)}
*/ */
public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType) { public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType) {
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace()); DataBuffer buffer = getDataBuffer(data.length, dataType);
buffer.setData(data); buffer.setData(data);
return buffer; return buffer;
} }
// refactoring of duplicate code.
private static DataBuffer getDataBuffer(int length, DataType dataType, MemoryWorkspace workspace){
return workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, workspace);
}
/** /**
* Create a buffer based on the data of the underlying java array, specified type and workspace * Create a buffer based on the data of the underlying java array, specified type and workspace
* @param data underlying java array * @param data underlying java array
@ -1598,61 +1570,18 @@ public class Nd4j {
* @return created buffer, * @return created buffer,
*/ */
public static DataBuffer createTypedBuffer(double[] data, DataType dataType, MemoryWorkspace workspace) { public static DataBuffer createTypedBuffer(double[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); //val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
DataBuffer buffer = getDataBuffer(data.length, dataType, workspace);
buffer.setData(data); buffer.setData(data);
return buffer; return buffer;
} }
/** /**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)} * See {@link #createTypedBuffer(double[], DataType, MemoryWorkspace)}
*/ */
public static DataBuffer createTypedBuffer(float[] data, DataType dataType, MemoryWorkspace workspace) { public static DataBuffer createTypedBuffer(float[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace); //val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data); DataBuffer buffer = getDataBuffer(data.length, dataType, workspace);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(int[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(long[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(short[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(byte[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data);
return buffer;
}
/**
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
*/
public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType, MemoryWorkspace workspace) {
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
buffer.setData(data); buffer.setData(data);
return buffer; return buffer;
} }
@ -1661,7 +1590,7 @@ public class Nd4j {
* Create am uninitialized buffer based on the data of the underlying java array and specified type. * Create am uninitialized buffer based on the data of the underlying java array and specified type.
* @param data underlying java array * @param data underlying java array
* @param dataType specified type. * @param dataType specified type.
* @return * @return the created buffer.
*/ */
public static DataBuffer createTypedBufferDetached(double[] data, DataType dataType) { public static DataBuffer createTypedBufferDetached(double[] data, DataType dataType) {
val buffer = DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false); val buffer = DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false);
@ -1731,14 +1660,6 @@ public class Nd4j {
INSTANCE = factory; INSTANCE = factory;
} }
/**
* Set the factory instance for sparse INDArray creation.
* @param factory new INDArray factory
*/
public static void setSparseFactory(NDArrayFactory factory) {
SPARSE_INSTANCE = factory;
}
/** /**
* Returns the ordering of the ndarrays * Returns the ordering of the ndarrays
* *
@ -1810,15 +1731,6 @@ public class Nd4j {
return SPARSE_BLAS_WRAPPER_INSTANCE; return SPARSE_BLAS_WRAPPER_INSTANCE;
} }
/**
* Sets the global blas wrapper
*
* @param factory
*/
public static void setBlasWrapper(BlasWrapper factory) {
BLAS_WRAPPER_INSTANCE = factory;
}
/** /**
* Sort an ndarray along a particular dimension.<br> * Sort an ndarray along a particular dimension.<br>
* Note that the input array is modified in-place. * Note that the input array is modified in-place.
@ -1899,7 +1811,7 @@ public class Nd4j {
* *
* @param ndarray array to sort * @param ndarray array to sort
* @param ascending true for ascending, false for descending * @param ascending true for ascending, false for descending
* @return * @return the sorted ndarray
*/ */
public static INDArray sort(INDArray ndarray, boolean ascending) { public static INDArray sort(INDArray ndarray, boolean ascending) {
return getNDArrayFactory().sort(ndarray, !ascending); return getNDArrayFactory().sort(ndarray, !ascending);
@ -1931,20 +1843,18 @@ public class Nd4j {
* @param in 2d array to sort * @param in 2d array to sort
* @param colIdx The column to sort on * @param colIdx The column to sort on
* @param ascending true if smallest-to-largest; false if largest-to-smallest * @param ascending true if smallest-to-largest; false if largest-to-smallest
* @return * @return the sorted ndarray
*/ */
@SuppressWarnings("Duplicates")
public static INDArray sortRows(final INDArray in, final int colIdx, final boolean ascending) { public static INDArray sortRows(final INDArray in, final int colIdx, final boolean ascending) {
if (in.rank() != 2) if (in.rank() != 2)
throw new IllegalArgumentException("Cannot sort rows on non-2d matrix"); throw new IllegalArgumentException("Cannot sort rows on non-2d matrix");
if (colIdx < 0 || colIdx >= in.columns()) if (colIdx < 0 || colIdx >= in.columns())
throw new IllegalArgumentException("Cannot sort on values in column " + colIdx + ", nCols=" + in.columns()); throw new IllegalArgumentException("Cannot sort on values in column " + colIdx + ", nCols=" + in.columns());
if (in.rows() > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
INDArray out = Nd4j.create(in.dataType(), in.shape()); INDArray out = Nd4j.create(in.dataType(), in.shape());
int nRows = (int) in.rows(); int nRows = in.rows();
ArrayList<Integer> list = new ArrayList<Integer>(nRows); ArrayList<Integer> list = new ArrayList<>(nRows);
for (int i = 0; i < nRows; i++) for (int i = 0; i < nRows; i++)
list.add(i); list.add(i);
Collections.sort(list, new Comparator<Integer>() { Collections.sort(list, new Comparator<Integer>() {
@ -1976,19 +1886,17 @@ public class Nd4j {
* @param in 2d array to sort * @param in 2d array to sort
* @param rowIdx The row to sort on * @param rowIdx The row to sort on
* @param ascending true if smallest-to-largest; false if largest-to-smallest * @param ascending true if smallest-to-largest; false if largest-to-smallest
* @return * @return the sorted array.
*/ */
@SuppressWarnings("Duplicates")
public static INDArray sortColumns(final INDArray in, final int rowIdx, final boolean ascending) { public static INDArray sortColumns(final INDArray in, final int rowIdx, final boolean ascending) {
if (in.rank() != 2) if (in.rank() != 2)
throw new IllegalArgumentException("Cannot sort columns on non-2d matrix"); throw new IllegalArgumentException("Cannot sort columns on non-2d matrix");
if (rowIdx < 0 || rowIdx >= in.rows()) if (rowIdx < 0 || rowIdx >= in.rows())
throw new IllegalArgumentException("Cannot sort on values in row " + rowIdx + ", nRows=" + in.rows()); throw new IllegalArgumentException("Cannot sort on values in row " + rowIdx + ", nRows=" + in.rows());
if (in.columns() > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
INDArray out = Nd4j.create(in.shape()); INDArray out = Nd4j.create(in.shape());
int nCols = (int) in.columns(); int nCols = in.columns();
ArrayList<Integer> list = new ArrayList<>(nCols); ArrayList<Integer> list = new ArrayList<>(nCols);
for (int i = 0; i < nCols; i++) for (int i = 0; i < nCols; i++)
list.add(i); list.add(i);
@ -2041,7 +1949,7 @@ public class Nd4j {
if (dtype.isIntType()) { if (dtype.isIntType()) {
long upper = lower + num * step; long upper = lower + num * step;
return linspaceWithCustomOpByRange((long) lower, upper, num, (long) step, dtype); return linspaceWithCustomOpByRange( lower, upper, num, step, dtype);
} else if (dtype.isFPType()) { } else if (dtype.isFPType()) {
return Nd4j.getExecutioner().exec(new Linspace((double) lower, num, (double)step, dtype)); return Nd4j.getExecutioner().exec(new Linspace((double) lower, num, (double)step, dtype));
} }
@ -2062,7 +1970,6 @@ public class Nd4j {
return linspace(lower, upper, num, Nd4j.dataType()); return linspace(lower, upper, num, Nd4j.dataType());
} }
/** /**
* Generate a linearly spaced vector * Generate a linearly spaced vector
* *
@ -2077,7 +1984,7 @@ public class Nd4j {
return Nd4j.scalar(dtype, lower); return Nd4j.scalar(dtype, lower);
} }
if (dtype.isIntType()) { if (dtype.isIntType()) {
return linspaceWithCustomOp((long)lower, (long)upper, (int)num, dtype); return linspaceWithCustomOp(lower, upper, (int)num, dtype);
} else if (dtype.isFPType()) { } else if (dtype.isFPType()) {
return linspace((double) lower, (double)upper, (int) num, dtype); return linspace((double) lower, (double)upper, (int) num, dtype);
} }
@ -2182,8 +2089,7 @@ public class Nd4j {
* these ndarrays * these ndarrays
*/ */
public static INDArray toFlattened(Collection<INDArray> matrices) { public static INDArray toFlattened(Collection<INDArray> matrices) {
INDArray ret = INSTANCE.toFlattened(matrices); return INSTANCE.toFlattened(matrices);
return ret;
} }
/** /**
@ -2194,27 +2100,7 @@ public class Nd4j {
* these ndarrays * these ndarrays
*/ */
public static INDArray toFlattened(char order, Collection<INDArray> matrices) { public static INDArray toFlattened(char order, Collection<INDArray> matrices) {
INDArray ret = INSTANCE.toFlattened(order, matrices); return INSTANCE.toFlattened(order, matrices);
return ret;
}
/**
* Create a long row vector of all of the given ndarrays
* @param matrices the matrices to create the flattened ndarray for
* @return the flattened representation of
* these ndarrays
*/
public static INDArray toFlattened(int length, Iterator<? extends INDArray>... matrices) {
INDArray ret = INSTANCE.toFlattened(length, matrices);
return ret;
}
/**
* Returns a column vector where each entry is the nth bilinear
* product of the nth slices of the two tensors.
*/
public static INDArray bilinearProducts(INDArray curr, INDArray in) {
return INSTANCE.bilinearProducts(curr, in);
} }
/** /**
@ -2243,32 +2129,28 @@ public class Nd4j {
* Create the identity ndarray * Create the identity ndarray
* *
* @param n the number for the identity * @param n the number for the identity
* @return * @return the identity array
*/ */
public static INDArray eye(long n) { public static INDArray eye(long n) {
INDArray ret = INSTANCE.eye(n); return INSTANCE.eye(n);
return ret;
} }
/** /**
* Rotate a matrix 90 degrees * Rotate a matrix 90 degrees
* *
* @param toRotate the matrix to rotate * @param toRotate the matrix to rotate
* @return the rotated matrix
*/ */
public static void rot90(INDArray toRotate) { public static void rot90(INDArray toRotate) {
INSTANCE.rot90(toRotate); INSTANCE.rot90(toRotate);
} }
/** /**
* Write NDArray to a text file * Write NDArray to a text file
* *
* @param filePath * @param filePath path to write to
* @param split the split separator, defaults to "," * @param split the split separator, defaults to ","
* @deprecated custom col separators are no longer supported; uses ","
* @param precision digits after the decimal point * @param precision digits after the decimal point
* @deprecated Precision is no longer used. * @deprecated Precision is no longer used. Split is no longer used.
* Defaults to scientific notation with 18 digits after the decimal * Defaults to scientific notation with 18 digits after the decimal
* Use {@link #writeTxt(INDArray, String)} * Use {@link #writeTxt(INDArray, String)}
*/ */
@ -2279,10 +2161,10 @@ public class Nd4j {
/** /**
* Write NDArray to a text file * Write NDArray to a text file
* *
* @param write * @param write array to write
* @param filePath * @param filePath path to write to
* @param precision * @param precision Precision is no longer used.
* @deprecated Precision is no longer used. * @deprecated
* Defaults to scientific notation with 18 digits after the decimal * Defaults to scientific notation with 18 digits after the decimal
* Use {@link #writeTxt(INDArray, String)} * Use {@link #writeTxt(INDArray, String)}
*/ */
@ -2293,9 +2175,9 @@ public class Nd4j {
/** /**
* Write NDArray to a text file * Write NDArray to a text file
* *
* @param write * @param write array to write
* @param filePath * @param filePath path to write to
* @param split * @param split the split separator, defaults to ","
* @deprecated custom col and higher dimension separators are no longer supported; uses "," * @deprecated custom col and higher dimension separators are no longer supported; uses ","
* Use {@link #writeTxt(INDArray, String)} * Use {@link #writeTxt(INDArray, String)}
*/ */
@ -2311,69 +2193,27 @@ public class Nd4j {
*/ */
public static void writeTxt(INDArray write, String filePath) { public static void writeTxt(INDArray write, String filePath) {
try { try {
String toWrite = writeStringForArray(write, "0.000000000000000000E0"); String toWrite = writeStringForArray(write);
FileUtils.writeStringToFile(new File(filePath), toWrite); FileUtils.writeStringToFile(new File(filePath), toWrite, (String)null, false);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException("Error writing output", e); throw new RuntimeException("Error writing output", e);
} }
} }
/** private static String writeStringForArray(INDArray write) {
* @deprecated custom col separators are no longer supported; uses ","
* @deprecated precision can no longer be specified. The array is written in scientific notation.
* @see #writeTxtString(INDArray, OutputStream)
*/
public static void writeTxtString(INDArray write, OutputStream os, String split, int precision) {
writeTxtString(write,os);
}
/**
* @deprecated precision can no longer be specified. The array is written in scientific notation.
* @see #writeTxtString(INDArray, OutputStream)
*/
@Deprecated
public static void writeTxtString(INDArray write, OutputStream os, int precision) {
writeTxtString(write,os);
}
/**
* @deprecated column separator can longer be specified; Uses ","
* @see #writeTxtString(INDArray, OutputStream)
*/
@Deprecated
public static void writeTxtString(INDArray write, OutputStream os, String split) {
writeTxtString(write, os);
}
/**
* Write ndarray as text to output stream
* @param write Array to write
* @param os stream to write too.
*/
public static void writeTxtString(INDArray write, OutputStream os) {
try {
// default format is "0.000000000000000000E0"
String toWrite = writeStringForArray(write, "0.000000000000000000E0");
os.write(toWrite.getBytes());
} catch (IOException e) {
throw new RuntimeException("Error writing output", e);
}
}
private static String writeStringForArray(INDArray write, String format) {
if(write.isView() || !Shape.hasDefaultStridesForShape(write)) if(write.isView() || !Shape.hasDefaultStridesForShape(write))
write = write.dup(); write = write.dup();
if (format.isEmpty()) format = "0.000000000000000000E0"; String format = "0.000000000000000000E0";
String lineOne = "{\n";
String lineTwo = "\"filefrom\": \"dl4j\",\n"; return new StringBuilder()
String lineThree = "\"ordering\": \"" + write.ordering() + "\",\n"; .append("{\n")
String lineFour = "\"shape\":\t" + java.util.Arrays.toString(write.shape()) + ",\n"; .append("\"filefrom\": \"dl4j\",\n")
String lineFive = "\"data\":\n"; .append( "\"ordering\": \"").append(write.ordering()).append("\",\n")
String fileData = new NDArrayStrings(",", format).format(write, false); .append("\"shape\":\t").append( java.util.Arrays.toString(write.shape())).append(",\n")
String fileEnd = "\n}\n"; .append("\"data\":\n")
String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive; .append(new NDArrayStrings(",", format).format(write, false))
String fileContents = fileBegin + fileData + fileEnd; .append("\n}\n")
return fileContents; .toString();
} }