Nd4j refactoring (#109)
* refactoring. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * review feedback Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
f49c4ea9d0
commit
c29c011d1a
|
@ -2222,7 +2222,6 @@ public class Nd4j {
|
||||||
* Write an ndarray to a writer
|
* Write an ndarray to a writer
|
||||||
* @param writer the writer to write to
|
* @param writer the writer to write to
|
||||||
* @param write the ndarray to write
|
* @param write the ndarray to write
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static void write(OutputStream writer, INDArray write) throws IOException {
|
public static void write(OutputStream writer, INDArray write) throws IOException {
|
||||||
DataOutputStream stream = new DataOutputStream(writer);
|
DataOutputStream stream = new DataOutputStream(writer);
|
||||||
|
@ -2230,14 +2229,12 @@ public class Nd4j {
|
||||||
stream.close();
|
stream.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert an ndarray to a byte array
|
* Convert an ndarray to a byte array
|
||||||
* @param arr the array to convert
|
* @param arr the array to convert
|
||||||
* @return the converted byte array
|
* @return the converted byte array
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static byte[] toByteArray(INDArray arr) throws IOException {
|
public static byte[] toByteArray(@NonNull INDArray arr) throws IOException {
|
||||||
if (arr.length() * arr.data().getElementSize() > Integer.MAX_VALUE)
|
if (arr.length() * arr.data().getElementSize() > Integer.MAX_VALUE)
|
||||||
throw new ND4JIllegalStateException("");
|
throw new ND4JIllegalStateException("");
|
||||||
|
|
||||||
|
@ -2252,15 +2249,13 @@ public class Nd4j {
|
||||||
* Read an ndarray from a byte array
|
* Read an ndarray from a byte array
|
||||||
* @param arr the array to read from
|
* @param arr the array to read from
|
||||||
* @return the deserialized ndarray
|
* @return the deserialized ndarray
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static INDArray fromByteArray(byte[] arr) throws IOException {
|
public static INDArray fromByteArray(@NonNull byte[] arr) throws IOException {
|
||||||
ByteArrayInputStream bis = new ByteArrayInputStream(arr);
|
ByteArrayInputStream bis = new ByteArrayInputStream(arr);
|
||||||
INDArray ret = read(bis);
|
INDArray ret = read(bis);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Read line via input streams
|
* Read line via input streams
|
||||||
*
|
*
|
||||||
|
@ -2280,7 +2275,6 @@ public class Nd4j {
|
||||||
* @param split the split separator
|
* @param split the split separator
|
||||||
* @param charset the charset
|
* @param charset the charset
|
||||||
* @return the deserialized array.
|
* @return the deserialized array.
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static INDArray readNumpy(@NonNull DataType dataType, @NonNull InputStream filePath, @NonNull String split, @NonNull Charset charset) throws IOException {
|
public static INDArray readNumpy(@NonNull DataType dataType, @NonNull InputStream filePath, @NonNull String split, @NonNull Charset charset) throws IOException {
|
||||||
BufferedReader reader = new BufferedReader(new InputStreamReader(filePath, charset));
|
BufferedReader reader = new BufferedReader(new InputStreamReader(filePath, charset));
|
||||||
|
@ -2449,7 +2443,7 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
//parse data
|
//parse data
|
||||||
if (lineNum > 5) {
|
if (lineNum > 5) {
|
||||||
String[] entries = line.replace("\\],", "").replaceAll("\\]", "").replaceAll("\\[", "").split(sep);
|
String[] entries = line.replace("\\],", "").replaceAll("]", "").replaceAll("\\[", "").split(sep);
|
||||||
if (rank == 0) {
|
if (rank == 0) {
|
||||||
try {
|
try {
|
||||||
newArr.addi((format.parse(entries[0])).doubleValue());
|
newArr.addi((format.parse(entries[0])).doubleValue());
|
||||||
|
@ -2496,7 +2490,6 @@ public class Nd4j {
|
||||||
* @return NDArray
|
* @return NDArray
|
||||||
*/
|
*/
|
||||||
public static INDArray readTxt(String filePath) {
|
public static INDArray readTxt(String filePath) {
|
||||||
String sep = ",";
|
|
||||||
File file = new File(filePath);
|
File file = new File(filePath);
|
||||||
InputStream is = null;
|
InputStream is = null;
|
||||||
try {
|
try {
|
||||||
|
@ -2505,13 +2498,7 @@ public class Nd4j {
|
||||||
} catch (FileNotFoundException e) {
|
} catch (FileNotFoundException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
} finally {
|
} finally {
|
||||||
if (is != null) {
|
IOUtils.closeQuietly(is);
|
||||||
try {
|
|
||||||
is.close();
|
|
||||||
} catch (IOException e) {
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2552,9 +2539,9 @@ public class Nd4j {
|
||||||
*/
|
*/
|
||||||
public static INDArray createArrayFromShapeBuffer(DataBuffer data, Pair<DataBuffer, long[]> shapeInfo) {
|
public static INDArray createArrayFromShapeBuffer(DataBuffer data, Pair<DataBuffer, long[]> shapeInfo) {
|
||||||
int rank = Shape.rank(shapeInfo.getFirst());
|
int rank = Shape.rank(shapeInfo.getFirst());
|
||||||
long offset = Shape.offset(shapeInfo.getFirst());
|
// removed offset parameter that called a deprecated method which always returns 0.
|
||||||
INDArray result = Nd4j.create(data, toIntArray(rank, Shape.shapeOf(shapeInfo.getFirst())),
|
INDArray result = Nd4j.create(data, toIntArray(rank, Shape.shapeOf(shapeInfo.getFirst())),
|
||||||
toIntArray(rank, Shape.stride(shapeInfo.getFirst())), offset, Shape.order(shapeInfo.getFirst()));
|
toIntArray(rank, Shape.stride(shapeInfo.getFirst())), 0, Shape.order(shapeInfo.getFirst()));
|
||||||
if (data instanceof CompressedDataBuffer)
|
if (data instanceof CompressedDataBuffer)
|
||||||
result.markAsCompressed(true);
|
result.markAsCompressed(true);
|
||||||
|
|
||||||
|
@ -2566,7 +2553,6 @@ public class Nd4j {
|
||||||
*
|
*
|
||||||
* @param dis the data input stream to read from
|
* @param dis the data input stream to read from
|
||||||
* @return the ndarray
|
* @return the ndarray
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static INDArray read(DataInputStream dis) throws IOException {
|
public static INDArray read(DataInputStream dis) throws IOException {
|
||||||
val headerShape = BaseDataBuffer.readHeader(dis);
|
val headerShape = BaseDataBuffer.readHeader(dis);
|
||||||
|
@ -2597,7 +2583,6 @@ public class Nd4j {
|
||||||
*
|
*
|
||||||
* @param arr the array to write
|
* @param arr the array to write
|
||||||
* @param dataOutputStream the data output stream to write to
|
* @param dataOutputStream the data output stream to write to
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static void write(INDArray arr, DataOutputStream dataOutputStream) throws IOException {
|
public static void write(INDArray arr, DataOutputStream dataOutputStream) throws IOException {
|
||||||
//BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here
|
//BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here
|
||||||
|
@ -2614,7 +2599,6 @@ public class Nd4j {
|
||||||
* Save an ndarray to the given file
|
* Save an ndarray to the given file
|
||||||
* @param arr the array to save
|
* @param arr the array to save
|
||||||
* @param saveTo the file to save to
|
* @param saveTo the file to save to
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static void saveBinary(INDArray arr, File saveTo) throws IOException {
|
public static void saveBinary(INDArray arr, File saveTo) throws IOException {
|
||||||
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(saveTo));
|
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(saveTo));
|
||||||
|
@ -2629,7 +2613,6 @@ public class Nd4j {
|
||||||
* Read a binary ndarray from the given file
|
* Read a binary ndarray from the given file
|
||||||
* @param read the nd array to read
|
* @param read the nd array to read
|
||||||
* @return the loaded ndarray
|
* @return the loaded ndarray
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static INDArray readBinary(File read) throws IOException {
|
public static INDArray readBinary(File read) throws IOException {
|
||||||
BufferedInputStream bis = new BufferedInputStream(new FileInputStream(read));
|
BufferedInputStream bis = new BufferedInputStream(new FileInputStream(read));
|
||||||
|
@ -2645,21 +2628,9 @@ public class Nd4j {
|
||||||
* @param arr the array to clear
|
* @param arr the array to clear
|
||||||
*/
|
*/
|
||||||
public static void clearNans(INDArray arr) {
|
public static void clearNans(INDArray arr) {
|
||||||
//BooleanIndexing.applyWhere(arr, Conditions.isNan(), new Value(Nd4j.EPS_THRESHOLD));
|
|
||||||
getExecutioner().exec(new ReplaceNans(arr, Nd4j.EPS_THRESHOLD));
|
getExecutioner().exec(new ReplaceNans(arr, Nd4j.EPS_THRESHOLD));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc
|
|
||||||
*
|
|
||||||
* @param reverse the matrix to reverse
|
|
||||||
* @return the reversed matrix
|
|
||||||
*/
|
|
||||||
public static INDArray rot(INDArray reverse) {
|
|
||||||
INDArray ret = INSTANCE.rot(reverse);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc
|
* Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc
|
||||||
*
|
*
|
||||||
|
@ -2667,8 +2638,6 @@ public class Nd4j {
|
||||||
* @return the reversed matrix
|
* @return the reversed matrix
|
||||||
*/
|
*/
|
||||||
public static INDArray reverse(INDArray reverse) {
|
public static INDArray reverse(INDArray reverse) {
|
||||||
//INDArray ret = INSTANCE.reverse(reverse);
|
|
||||||
//logCreationIfNecessary(ret);
|
|
||||||
return Nd4j.getExecutioner().exec(new OldReverse(reverse));
|
return Nd4j.getExecutioner().exec(new OldReverse(reverse));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2682,8 +2651,7 @@ public class Nd4j {
|
||||||
* @return the 1D range vector
|
* @return the 1D range vector
|
||||||
*/
|
*/
|
||||||
public static INDArray arange(double begin, double end, double step) {
|
public static INDArray arange(double begin, double end, double step) {
|
||||||
INDArray ret = INSTANCE.arange(begin, end, step);
|
return INSTANCE.arange(begin, end, step);
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2693,8 +2661,7 @@ public class Nd4j {
|
||||||
* See {@link #arange(double, double, double)} with step size 1.
|
* See {@link #arange(double, double, double)} with step size 1.
|
||||||
*/
|
*/
|
||||||
public static INDArray arange(double begin, double end) {
|
public static INDArray arange(double begin, double end) {
|
||||||
INDArray ret = INSTANCE.arange(begin, end, 1);
|
return INSTANCE.arange(begin, end, 1);
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2717,27 +2684,6 @@ public class Nd4j {
|
||||||
INSTANCE.copy(a, b);
|
INSTANCE.copy(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a new matrix where the values of the given vector are the diagonal values of
|
|
||||||
* the matrix if a vector is passed in, if a matrix is returns the kth diagonal
|
|
||||||
* in the matrix
|
|
||||||
*
|
|
||||||
* @param x the diagonal values
|
|
||||||
* @param k the kth diagonal to get
|
|
||||||
* @return new matrix
|
|
||||||
*/
|
|
||||||
public static INDArray diag(INDArray x, int k) {
|
|
||||||
INDArray ret;
|
|
||||||
if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) {
|
|
||||||
ret = Nd4j.create(new long[]{x.length(), x.length()});
|
|
||||||
Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret}));
|
|
||||||
} else {
|
|
||||||
ret = Nd4j.createUninitialized(new long[]{Math.min(x.size(0), x.size(1))});
|
|
||||||
Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret));
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new matrix where the values of the given vector are the diagonal values of
|
* Creates a new matrix where the values of the given vector are the diagonal values of
|
||||||
* the matrix if a vector is passed in, if a matrix is returns the kth diagonal
|
* the matrix if a vector is passed in, if a matrix is returns the kth diagonal
|
||||||
|
@ -2747,7 +2693,15 @@ public class Nd4j {
|
||||||
* @return new matrix
|
* @return new matrix
|
||||||
*/
|
*/
|
||||||
public static INDArray diag(INDArray x) {
|
public static INDArray diag(INDArray x) {
|
||||||
return diag(x, 0);
|
INDArray ret;
|
||||||
|
if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) {
|
||||||
|
ret = Nd4j.create(x.dataType(), x.length(), x.length());
|
||||||
|
Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret}));
|
||||||
|
} else {
|
||||||
|
ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1)));
|
||||||
|
Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret));
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
Loading…
Reference in New Issue