parent
edb71bf46f
commit
f8615e0ef0
|
@ -1,144 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ndarray;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.LineIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Created by susaneraly on 6/16/16.
|
||||
*/
|
||||
@Deprecated
|
||||
public class NdArrayJSONReader {
|
||||
|
||||
public INDArray read(File jsonFile) {
|
||||
INDArray result = this.loadNative(jsonFile);
|
||||
if (result == null) {
|
||||
//Must write support for parsing/normal json parsing - which will be inefficient
|
||||
this.loadNonNative(jsonFile);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private INDArray loadNative(File jsonFile) {
|
||||
/*
|
||||
We could dump an ndarray to a file with the tostring (since that is valid json) and use put/get to parse it as json
|
||||
|
||||
But here we leverage our information of the tostring method to be more efficient
|
||||
With our current toString format we use tads along dimension (rank-1,rank-2) to write to the array in two dimensional chunks at a time.
|
||||
This is more efficient than setting each value at a time with putScalar.
|
||||
This also means we can read the file one line at a time instead of loading the whole thing into memory
|
||||
|
||||
Future work involves enhancing the write json method to provide more features to make the load more efficient
|
||||
*/
|
||||
int lineNum = 0;
|
||||
int rowNum = 0;
|
||||
int tensorNum = 0;
|
||||
char theOrder = 'c';
|
||||
int[] theShape = {1, 1};
|
||||
int rank = 0;
|
||||
double[][] subsetArr = {{0.0, 0.0}, {0.0, 0.0}};
|
||||
INDArray newArr = Nd4j.zeros(2, 2);
|
||||
try {
|
||||
LineIterator it = FileUtils.lineIterator(jsonFile);
|
||||
try {
|
||||
while (it.hasNext()) {
|
||||
String line = it.nextLine();
|
||||
lineNum++;
|
||||
line = line.replaceAll("\\s", "");
|
||||
if (line.equals("") || line.equals("}"))
|
||||
continue;
|
||||
// is it from dl4j?
|
||||
if (lineNum == 2) {
|
||||
String[] lineArr = line.split(":");
|
||||
String fileSource = lineArr[1].replaceAll("\\W", "");
|
||||
if (!fileSource.equals("dl4j"))
|
||||
return null;
|
||||
}
|
||||
// parse ordering
|
||||
if (lineNum == 3) {
|
||||
String[] lineArr = line.split(":");
|
||||
theOrder = lineArr[1].replace("\\W", "").charAt(0);
|
||||
continue;
|
||||
}
|
||||
// parse shape
|
||||
if (lineNum == 4) {
|
||||
String[] lineArr = line.split(":");
|
||||
String dropJsonComma = lineArr[1].split("]")[0];
|
||||
String[] shapeString = dropJsonComma.replace("[", "").split(",");
|
||||
rank = shapeString.length;
|
||||
theShape = new int[rank];
|
||||
for (int i = 0; i < rank; i++) {
|
||||
try {
|
||||
theShape[i] = Integer.parseInt(shapeString[i]);
|
||||
} catch (NumberFormatException nfe) {
|
||||
} ;
|
||||
}
|
||||
subsetArr = new double[theShape[rank - 2]][theShape[rank - 1]];
|
||||
newArr = Nd4j.zeros(theShape, theOrder);
|
||||
continue;
|
||||
}
|
||||
//parse data
|
||||
if (lineNum > 5) {
|
||||
String[] entries =
|
||||
line.replace("\\],", "").replaceAll("\\[", "").replaceAll("\\]", "").split(",");
|
||||
for (int i = 0; i < theShape[rank - 1]; i++) {
|
||||
try {
|
||||
subsetArr[rowNum][i] = Double.parseDouble(entries[i]);
|
||||
} catch (NumberFormatException nfe) {
|
||||
}
|
||||
}
|
||||
rowNum++;
|
||||
if (rowNum == theShape[rank - 2]) {
|
||||
INDArray subTensor = Nd4j.create(subsetArr);
|
||||
newArr.tensorAlongDimension(tensorNum, rank - 1, rank - 2).addi(subTensor);
|
||||
rowNum = 0;
|
||||
tensorNum++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
LineIterator.closeQuietly(it);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Error reading input", e);
|
||||
}
|
||||
return newArr;
|
||||
}
|
||||
|
||||
private INDArray loadNonNative(File jsonFile) {
|
||||
/* WIP
|
||||
JSONTokener tokener = new JSONTokener(new FileReader("test.json"));
|
||||
JSONObject obj = new JSONObject(tokener);
|
||||
JSONArray objArr = obj.optJSONArray("shape");
|
||||
int rank = objArr.length();
|
||||
int[] theShape = new int[rank];
|
||||
int rows = 1;
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
theShape[i] = objArr.optInt(i);
|
||||
if (i != objArr.length() - 1)
|
||||
rows *= theShape[i];
|
||||
}
|
||||
*/
|
||||
System.out.println("API_Error: Current support only for files written from dl4j");
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ndarray;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Created by susaneraly on 6/18/16.
|
||||
*/
|
||||
@Deprecated
|
||||
public class NdArrayJSONWriter {
|
||||
private NdArrayJSONWriter() {}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param thisnD
|
||||
* @param filePath
|
||||
*/
|
||||
public static void write(INDArray thisnD, String filePath) {
|
||||
//TO DO: Add precision support in toString
|
||||
//TO DO: Write to file one line at time
|
||||
String lineOne = "{\n";
|
||||
String lineTwo = "\"filefrom\": \"dl4j\",\n";
|
||||
String lineThree = "\"ordering\": \"" + thisnD.ordering() + "\",\n";
|
||||
String lineFour = "\"shape\":\t" + java.util.Arrays.toString(thisnD.shape()) + ",\n";
|
||||
String lineFive = "\"data\":\n";
|
||||
String fileData = thisnD.toString();
|
||||
String fileEnd = "\n}\n";
|
||||
|
||||
String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive;
|
||||
try {
|
||||
FileUtils.writeStringToFile(new File(filePath), fileBegin + fileData + fileEnd);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Error writing output", e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -75,7 +75,6 @@ import org.nd4j.linalg.compression.CompressedDataBuffer;
|
|||
import org.nd4j.linalg.convolution.ConvolutionInstance;
|
||||
import org.nd4j.linalg.convolution.DefaultConvolutionInstance;
|
||||
import org.nd4j.linalg.env.EnvironmentalAction;
|
||||
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.exception.ND4JUnknownDataTypeException;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend.NoAvailableBackendException;
|
||||
|
@ -1073,7 +1072,7 @@ public class Nd4j {
|
|||
return ret;
|
||||
}
|
||||
|
||||
protected static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
|
||||
private static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
|
||||
switch (dataType) {
|
||||
case UINT64:
|
||||
case LONG:
|
||||
|
@ -1113,40 +1112,7 @@ public class Nd4j {
|
|||
* @return the created buffer
|
||||
*/
|
||||
public static DataBuffer createBuffer(@NonNull Pointer pointer, long length, @NonNull DataType dataType) {
|
||||
Pointer nPointer = null;
|
||||
//TODO: remove dupplicate code.
|
||||
switch (dataType) {
|
||||
case LONG:
|
||||
nPointer = new LongPointer(pointer);
|
||||
break;
|
||||
case INT:
|
||||
nPointer = new IntPointer(pointer);
|
||||
break;
|
||||
case SHORT:
|
||||
nPointer = new ShortPointer(pointer);
|
||||
break;
|
||||
case BYTE:
|
||||
nPointer = new BytePointer(pointer);
|
||||
break;
|
||||
case UBYTE:
|
||||
nPointer = new BytePointer(pointer);
|
||||
break;
|
||||
case BOOL:
|
||||
nPointer = new BooleanPointer(pointer);
|
||||
break;
|
||||
case FLOAT:
|
||||
nPointer = new FloatPointer(pointer);
|
||||
break;
|
||||
case HALF:
|
||||
nPointer = new ShortPointer(pointer);
|
||||
break;
|
||||
case DOUBLE:
|
||||
nPointer = new DoublePointer(pointer);
|
||||
break;
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unsupported data type: " + dataType);
|
||||
}
|
||||
|
||||
Pointer nPointer = getPointer(pointer, dataType);
|
||||
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, dataType, length, getIndexerByType(nPointer, dataType));
|
||||
}
|
||||
|
||||
|
@ -1161,7 +1127,12 @@ public class Nd4j {
|
|||
* @return the created buffer
|
||||
*/
|
||||
public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) {
|
||||
Pointer nPointer = null;
|
||||
Pointer nPointer = getPointer(pointer, dataType);
|
||||
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType));
|
||||
}
|
||||
|
||||
private static Pointer getPointer(@NonNull Pointer pointer, @NonNull DataType dataType ){
|
||||
Pointer nPointer;
|
||||
switch (dataType) {
|
||||
case UINT64:
|
||||
case LONG:
|
||||
|
@ -1198,7 +1169,7 @@ public class Nd4j {
|
|||
throw new UnsupportedOperationException("Unsupported data type: " + dataType);
|
||||
}
|
||||
|
||||
return DATA_BUFFER_FACTORY_INSTANCE.create(nPointer, devicePointer, dataType, length, getIndexerByType(nPointer, dataType));
|
||||
return nPointer;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1208,9 +1179,8 @@ public class Nd4j {
|
|||
* @return the created buffer
|
||||
*/
|
||||
public static DataBuffer createBuffer(float[] data, long offset) {
|
||||
val ndata = Arrays.copyOfRange(data, (int) offset, data.length);
|
||||
DataBuffer ret = createTypedBuffer(ndata, DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
return ret;
|
||||
return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length),
|
||||
DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1220,9 +1190,8 @@ public class Nd4j {
|
|||
* @return the created buffer
|
||||
*/
|
||||
public static DataBuffer createBuffer(double[] data, long offset) {
|
||||
val ndata = Arrays.copyOfRange(data, (int) offset, data.length);
|
||||
DataBuffer ret = createTypedBuffer(ndata, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
return ret;
|
||||
return createTypedBuffer(Arrays.copyOfRange(data, (int) offset, data.length),
|
||||
DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1294,7 +1263,7 @@ public class Nd4j {
|
|||
*
|
||||
* @param shape the shape of the buffer to create
|
||||
* @param type the opType to create
|
||||
* @return
|
||||
* @return the created buffer.
|
||||
*/
|
||||
public static DataBuffer createBufferDetached(int[] shape, DataType type) {
|
||||
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type);
|
||||
|
@ -1354,7 +1323,7 @@ public class Nd4j {
|
|||
* @param buffer the buffer to create from
|
||||
* @param type the opType of buffer to create
|
||||
* @param length the length of the buffer
|
||||
* @return
|
||||
* @return the created buffer
|
||||
*/
|
||||
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) {
|
||||
switch (type) {
|
||||
|
@ -1407,33 +1376,27 @@ public class Nd4j {
|
|||
* @return the created buffer
|
||||
*/
|
||||
public static DataBuffer createBuffer(long[] data) {
|
||||
DataBuffer ret;
|
||||
ret = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
return ret;
|
||||
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(data) : DATA_BUFFER_FACTORY_INSTANCE.createLong(data, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a buffer equal of length prod(shape). This method is NOT affected by workspaces
|
||||
*
|
||||
* @param data
|
||||
* @return
|
||||
* @param data the shape of the buffer to create
|
||||
* @return the created buffer
|
||||
*/
|
||||
public static DataBuffer createBufferDetached(int[] data) {
|
||||
DataBuffer ret;
|
||||
ret = DATA_BUFFER_FACTORY_INSTANCE.createInt(data);
|
||||
return ret;
|
||||
return DATA_BUFFER_FACTORY_INSTANCE.createInt(data);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a buffer equal of length prod(shape). This method is NOT affected by workspaces
|
||||
*
|
||||
* @param data
|
||||
* @return
|
||||
* @param data the shape of the buffer to create
|
||||
* @return the created buffer
|
||||
*/
|
||||
public static DataBuffer createBufferDetached(long[] data) {
|
||||
DataBuffer ret;
|
||||
ret = DATA_BUFFER_FACTORY_INSTANCE.createLong(data);
|
||||
return ret;
|
||||
return DATA_BUFFER_FACTORY_INSTANCE.createLong(data);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1464,9 +1427,7 @@ public class Nd4j {
|
|||
* See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype.
|
||||
*/
|
||||
public static DataBuffer createBuffer(long length, boolean initialize) {
|
||||
DataBuffer ret = createBuffer(Nd4j.dataType(), length, initialize);
|
||||
|
||||
return ret;
|
||||
return createBuffer(Nd4j.dataType(), length, initialize);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1524,6 +1485,11 @@ public class Nd4j {
|
|||
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(data) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
// refactoring of duplicate code.
|
||||
private static DataBuffer getDataBuffer(int length, DataType dataType){
|
||||
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a buffer based on the data of the underlying java array with the specified type..
|
||||
* @param data underlying java array
|
||||
|
@ -1531,16 +1497,16 @@ public class Nd4j {
|
|||
* @return created buffer,
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(double[] data, DataType dataType) {
|
||||
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
DataBuffer buffer = getDataBuffer(data.length, dataType);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createTypedBuffer(float[], DataType)}
|
||||
* See {@link #createTypedBuffer(double[], DataType)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(float[] data, DataType dataType) {
|
||||
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
DataBuffer buffer = getDataBuffer(data.length, dataType);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
@ -1549,7 +1515,7 @@ public class Nd4j {
|
|||
* See {@link #createTypedBuffer(float[], DataType)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(int[] data, DataType dataType) {
|
||||
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
DataBuffer buffer = getDataBuffer(data.length, dataType);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
@ -1558,7 +1524,7 @@ public class Nd4j {
|
|||
* See {@link #createTypedBuffer(float[], DataType)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(long[] data, DataType dataType) {
|
||||
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
DataBuffer buffer = getDataBuffer(data.length, dataType);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
@ -1567,7 +1533,7 @@ public class Nd4j {
|
|||
* See {@link #createTypedBuffer(float[], DataType)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(short[] data, DataType dataType) {
|
||||
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
DataBuffer buffer = getDataBuffer(data.length, dataType);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
@ -1576,7 +1542,7 @@ public class Nd4j {
|
|||
* See {@link #createTypedBuffer(float[], DataType)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(byte[] data, DataType dataType) {
|
||||
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
DataBuffer buffer = getDataBuffer(data.length, dataType);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
@ -1585,11 +1551,17 @@ public class Nd4j {
|
|||
* See {@link #createTypedBuffer(float[], DataType)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType) {
|
||||
val buffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
DataBuffer buffer = getDataBuffer(data.length, dataType);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
|
||||
// refactoring of duplicate code.
|
||||
private static DataBuffer getDataBuffer(int length, DataType dataType, MemoryWorkspace workspace){
|
||||
return workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, length, false, workspace);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a buffer based on the data of the underlying java array, specified type and workspace
|
||||
* @param data underlying java array
|
||||
|
@ -1598,61 +1570,18 @@ public class Nd4j {
|
|||
* @return created buffer,
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(double[] data, DataType dataType, MemoryWorkspace workspace) {
|
||||
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
|
||||
//val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
|
||||
DataBuffer buffer = getDataBuffer(data.length, dataType, workspace);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
|
||||
* See {@link #createTypedBuffer(double[], DataType, MemoryWorkspace)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(float[] data, DataType dataType, MemoryWorkspace workspace) {
|
||||
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(int[] data, DataType dataType, MemoryWorkspace workspace) {
|
||||
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(long[] data, DataType dataType, MemoryWorkspace workspace) {
|
||||
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(short[] data, DataType dataType, MemoryWorkspace workspace) {
|
||||
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(byte[] data, DataType dataType, MemoryWorkspace workspace) {
|
||||
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createTypedBuffer(float[], DataType, MemoryWorkspace)}
|
||||
*/
|
||||
public static DataBuffer createTypedBuffer(boolean[] data, DataType dataType, MemoryWorkspace workspace) {
|
||||
val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
|
||||
//val buffer = workspace == null ? DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false) : DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false, workspace);
|
||||
DataBuffer buffer = getDataBuffer(data.length, dataType, workspace);
|
||||
buffer.setData(data);
|
||||
return buffer;
|
||||
}
|
||||
|
@ -1661,7 +1590,7 @@ public class Nd4j {
|
|||
* Create am uninitialized buffer based on the data of the underlying java array and specified type.
|
||||
* @param data underlying java array
|
||||
* @param dataType specified type.
|
||||
* @return
|
||||
* @return the created buffer.
|
||||
*/
|
||||
public static DataBuffer createTypedBufferDetached(double[] data, DataType dataType) {
|
||||
val buffer = DATA_BUFFER_FACTORY_INSTANCE.create(dataType, data.length, false);
|
||||
|
@ -1731,14 +1660,6 @@ public class Nd4j {
|
|||
INSTANCE = factory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the factory instance for sparse INDArray creation.
|
||||
* @param factory new INDArray factory
|
||||
*/
|
||||
public static void setSparseFactory(NDArrayFactory factory) {
|
||||
SPARSE_INSTANCE = factory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the ordering of the ndarrays
|
||||
*
|
||||
|
@ -1810,15 +1731,6 @@ public class Nd4j {
|
|||
return SPARSE_BLAS_WRAPPER_INSTANCE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the global blas wrapper
|
||||
*
|
||||
* @param factory
|
||||
*/
|
||||
public static void setBlasWrapper(BlasWrapper factory) {
|
||||
BLAS_WRAPPER_INSTANCE = factory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort an ndarray along a particular dimension.<br>
|
||||
* Note that the input array is modified in-place.
|
||||
|
@ -1899,7 +1811,7 @@ public class Nd4j {
|
|||
*
|
||||
* @param ndarray array to sort
|
||||
* @param ascending true for ascending, false for descending
|
||||
* @return
|
||||
* @return the sorted ndarray
|
||||
*/
|
||||
public static INDArray sort(INDArray ndarray, boolean ascending) {
|
||||
return getNDArrayFactory().sort(ndarray, !ascending);
|
||||
|
@ -1931,20 +1843,18 @@ public class Nd4j {
|
|||
* @param in 2d array to sort
|
||||
* @param colIdx The column to sort on
|
||||
* @param ascending true if smallest-to-largest; false if largest-to-smallest
|
||||
* @return
|
||||
* @return the sorted ndarray
|
||||
*/
|
||||
@SuppressWarnings("Duplicates")
|
||||
public static INDArray sortRows(final INDArray in, final int colIdx, final boolean ascending) {
|
||||
if (in.rank() != 2)
|
||||
throw new IllegalArgumentException("Cannot sort rows on non-2d matrix");
|
||||
if (colIdx < 0 || colIdx >= in.columns())
|
||||
throw new IllegalArgumentException("Cannot sort on values in column " + colIdx + ", nCols=" + in.columns());
|
||||
|
||||
if (in.rows() > Integer.MAX_VALUE)
|
||||
throw new ND4JArraySizeException();
|
||||
|
||||
INDArray out = Nd4j.create(in.dataType(), in.shape());
|
||||
int nRows = (int) in.rows();
|
||||
ArrayList<Integer> list = new ArrayList<Integer>(nRows);
|
||||
int nRows = in.rows();
|
||||
ArrayList<Integer> list = new ArrayList<>(nRows);
|
||||
for (int i = 0; i < nRows; i++)
|
||||
list.add(i);
|
||||
Collections.sort(list, new Comparator<Integer>() {
|
||||
|
@ -1976,19 +1886,17 @@ public class Nd4j {
|
|||
* @param in 2d array to sort
|
||||
* @param rowIdx The row to sort on
|
||||
* @param ascending true if smallest-to-largest; false if largest-to-smallest
|
||||
* @return
|
||||
* @return the sorted array.
|
||||
*/
|
||||
@SuppressWarnings("Duplicates")
|
||||
public static INDArray sortColumns(final INDArray in, final int rowIdx, final boolean ascending) {
|
||||
if (in.rank() != 2)
|
||||
throw new IllegalArgumentException("Cannot sort columns on non-2d matrix");
|
||||
if (rowIdx < 0 || rowIdx >= in.rows())
|
||||
throw new IllegalArgumentException("Cannot sort on values in row " + rowIdx + ", nRows=" + in.rows());
|
||||
|
||||
if (in.columns() > Integer.MAX_VALUE)
|
||||
throw new ND4JArraySizeException();
|
||||
|
||||
INDArray out = Nd4j.create(in.shape());
|
||||
int nCols = (int) in.columns();
|
||||
int nCols = in.columns();
|
||||
ArrayList<Integer> list = new ArrayList<>(nCols);
|
||||
for (int i = 0; i < nCols; i++)
|
||||
list.add(i);
|
||||
|
@ -2041,7 +1949,7 @@ public class Nd4j {
|
|||
|
||||
if (dtype.isIntType()) {
|
||||
long upper = lower + num * step;
|
||||
return linspaceWithCustomOpByRange((long) lower, upper, num, (long) step, dtype);
|
||||
return linspaceWithCustomOpByRange( lower, upper, num, step, dtype);
|
||||
} else if (dtype.isFPType()) {
|
||||
return Nd4j.getExecutioner().exec(new Linspace((double) lower, num, (double)step, dtype));
|
||||
}
|
||||
|
@ -2062,7 +1970,6 @@ public class Nd4j {
|
|||
return linspace(lower, upper, num, Nd4j.dataType());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generate a linearly spaced vector
|
||||
*
|
||||
|
@ -2077,7 +1984,7 @@ public class Nd4j {
|
|||
return Nd4j.scalar(dtype, lower);
|
||||
}
|
||||
if (dtype.isIntType()) {
|
||||
return linspaceWithCustomOp((long)lower, (long)upper, (int)num, dtype);
|
||||
return linspaceWithCustomOp(lower, upper, (int)num, dtype);
|
||||
} else if (dtype.isFPType()) {
|
||||
return linspace((double) lower, (double)upper, (int) num, dtype);
|
||||
}
|
||||
|
@ -2182,8 +2089,7 @@ public class Nd4j {
|
|||
* these ndarrays
|
||||
*/
|
||||
public static INDArray toFlattened(Collection<INDArray> matrices) {
|
||||
INDArray ret = INSTANCE.toFlattened(matrices);
|
||||
return ret;
|
||||
return INSTANCE.toFlattened(matrices);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -2194,27 +2100,7 @@ public class Nd4j {
|
|||
* these ndarrays
|
||||
*/
|
||||
public static INDArray toFlattened(char order, Collection<INDArray> matrices) {
|
||||
INDArray ret = INSTANCE.toFlattened(order, matrices);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a long row vector of all of the given ndarrays
|
||||
* @param matrices the matrices to create the flattened ndarray for
|
||||
* @return the flattened representation of
|
||||
* these ndarrays
|
||||
*/
|
||||
public static INDArray toFlattened(int length, Iterator<? extends INDArray>... matrices) {
|
||||
INDArray ret = INSTANCE.toFlattened(length, matrices);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a column vector where each entry is the nth bilinear
|
||||
* product of the nth slices of the two tensors.
|
||||
*/
|
||||
public static INDArray bilinearProducts(INDArray curr, INDArray in) {
|
||||
return INSTANCE.bilinearProducts(curr, in);
|
||||
return INSTANCE.toFlattened(order, matrices);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -2243,32 +2129,28 @@ public class Nd4j {
|
|||
* Create the identity ndarray
|
||||
*
|
||||
* @param n the number for the identity
|
||||
* @return
|
||||
* @return the identity array
|
||||
*/
|
||||
public static INDArray eye(long n) {
|
||||
INDArray ret = INSTANCE.eye(n);
|
||||
return ret;
|
||||
return INSTANCE.eye(n);
|
||||
}
|
||||
|
||||
/**
|
||||
* Rotate a matrix 90 degrees
|
||||
*
|
||||
* @param toRotate the matrix to rotate
|
||||
* @return the rotated matrix
|
||||
*/
|
||||
public static void rot90(INDArray toRotate) {
|
||||
INSTANCE.rot90(toRotate);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Write NDArray to a text file
|
||||
*
|
||||
* @param filePath
|
||||
* @param filePath path to write to
|
||||
* @param split the split separator, defaults to ","
|
||||
* @deprecated custom col separators are no longer supported; uses ","
|
||||
* @param precision digits after the decimal point
|
||||
* @deprecated Precision is no longer used.
|
||||
* @deprecated Precision is no longer used. Split is no longer used.
|
||||
* Defaults to scientific notation with 18 digits after the decimal
|
||||
* Use {@link #writeTxt(INDArray, String)}
|
||||
*/
|
||||
|
@ -2279,10 +2161,10 @@ public class Nd4j {
|
|||
/**
|
||||
* Write NDArray to a text file
|
||||
*
|
||||
* @param write
|
||||
* @param filePath
|
||||
* @param precision
|
||||
* @deprecated Precision is no longer used.
|
||||
* @param write array to write
|
||||
* @param filePath path to write to
|
||||
* @param precision Precision is no longer used.
|
||||
* @deprecated
|
||||
* Defaults to scientific notation with 18 digits after the decimal
|
||||
* Use {@link #writeTxt(INDArray, String)}
|
||||
*/
|
||||
|
@ -2293,9 +2175,9 @@ public class Nd4j {
|
|||
/**
|
||||
* Write NDArray to a text file
|
||||
*
|
||||
* @param write
|
||||
* @param filePath
|
||||
* @param split
|
||||
* @param write array to write
|
||||
* @param filePath path to write to
|
||||
* @param split the split separator, defaults to ","
|
||||
* @deprecated custom col and higher dimension separators are no longer supported; uses ","
|
||||
* Use {@link #writeTxt(INDArray, String)}
|
||||
*/
|
||||
|
@ -2311,69 +2193,27 @@ public class Nd4j {
|
|||
*/
|
||||
public static void writeTxt(INDArray write, String filePath) {
|
||||
try {
|
||||
String toWrite = writeStringForArray(write, "0.000000000000000000E0");
|
||||
FileUtils.writeStringToFile(new File(filePath), toWrite);
|
||||
String toWrite = writeStringForArray(write);
|
||||
FileUtils.writeStringToFile(new File(filePath), toWrite, (String)null, false);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Error writing output", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated custom col separators are no longer supported; uses ","
|
||||
* @deprecated precision can no longer be specified. The array is written in scientific notation.
|
||||
* @see #writeTxtString(INDArray, OutputStream)
|
||||
*/
|
||||
public static void writeTxtString(INDArray write, OutputStream os, String split, int precision) {
|
||||
writeTxtString(write,os);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated precision can no longer be specified. The array is written in scientific notation.
|
||||
* @see #writeTxtString(INDArray, OutputStream)
|
||||
*/
|
||||
@Deprecated
|
||||
public static void writeTxtString(INDArray write, OutputStream os, int precision) {
|
||||
writeTxtString(write,os);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated column separator can longer be specified; Uses ","
|
||||
* @see #writeTxtString(INDArray, OutputStream)
|
||||
*/
|
||||
@Deprecated
|
||||
public static void writeTxtString(INDArray write, OutputStream os, String split) {
|
||||
writeTxtString(write, os);
|
||||
}
|
||||
|
||||
/**
|
||||
* Write ndarray as text to output stream
|
||||
* @param write Array to write
|
||||
* @param os stream to write too.
|
||||
*/
|
||||
public static void writeTxtString(INDArray write, OutputStream os) {
|
||||
try {
|
||||
// default format is "0.000000000000000000E0"
|
||||
String toWrite = writeStringForArray(write, "0.000000000000000000E0");
|
||||
os.write(toWrite.getBytes());
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Error writing output", e);
|
||||
}
|
||||
}
|
||||
|
||||
private static String writeStringForArray(INDArray write, String format) {
|
||||
private static String writeStringForArray(INDArray write) {
|
||||
if(write.isView() || !Shape.hasDefaultStridesForShape(write))
|
||||
write = write.dup();
|
||||
if (format.isEmpty()) format = "0.000000000000000000E0";
|
||||
String lineOne = "{\n";
|
||||
String lineTwo = "\"filefrom\": \"dl4j\",\n";
|
||||
String lineThree = "\"ordering\": \"" + write.ordering() + "\",\n";
|
||||
String lineFour = "\"shape\":\t" + java.util.Arrays.toString(write.shape()) + ",\n";
|
||||
String lineFive = "\"data\":\n";
|
||||
String fileData = new NDArrayStrings(",", format).format(write, false);
|
||||
String fileEnd = "\n}\n";
|
||||
String fileBegin = lineOne + lineTwo + lineThree + lineFour + lineFive;
|
||||
String fileContents = fileBegin + fileData + fileEnd;
|
||||
return fileContents;
|
||||
String format = "0.000000000000000000E0";
|
||||
|
||||
return new StringBuilder()
|
||||
.append("{\n")
|
||||
.append("\"filefrom\": \"dl4j\",\n")
|
||||
.append( "\"ordering\": \"").append(write.ordering()).append("\",\n")
|
||||
.append("\"shape\":\t").append( java.util.Arrays.toString(write.shape())).append(",\n")
|
||||
.append("\"data\":\n")
|
||||
.append(new NDArrayStrings(",", format).format(write, false))
|
||||
.append("\n}\n")
|
||||
.toString();
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue