Nd4j refactoring (#109)
* refactoring. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * review feedback Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
f49c4ea9d0
commit
c29c011d1a
|
@ -2222,7 +2222,6 @@ public class Nd4j {
|
|||
* Write an ndarray to a writer
|
||||
* @param writer the writer to write to
|
||||
* @param write the ndarray to write
|
||||
* @throws IOException
|
||||
*/
|
||||
public static void write(OutputStream writer, INDArray write) throws IOException {
|
||||
DataOutputStream stream = new DataOutputStream(writer);
|
||||
|
@ -2230,14 +2229,12 @@ public class Nd4j {
|
|||
stream.close();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Convert an ndarray to a byte array
|
||||
* @param arr the array to convert
|
||||
* @return the converted byte array
|
||||
* @throws IOException
|
||||
*/
|
||||
public static byte[] toByteArray(INDArray arr) throws IOException {
|
||||
public static byte[] toByteArray(@NonNull INDArray arr) throws IOException {
|
||||
if (arr.length() * arr.data().getElementSize() > Integer.MAX_VALUE)
|
||||
throw new ND4JIllegalStateException("");
|
||||
|
||||
|
@ -2252,15 +2249,13 @@ public class Nd4j {
|
|||
* Read an ndarray from a byte array
|
||||
* @param arr the array to read from
|
||||
* @return the deserialized ndarray
|
||||
* @throws IOException
|
||||
*/
|
||||
public static INDArray fromByteArray(byte[] arr) throws IOException {
|
||||
public static INDArray fromByteArray(@NonNull byte[] arr) throws IOException {
|
||||
ByteArrayInputStream bis = new ByteArrayInputStream(arr);
|
||||
INDArray ret = read(bis);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Read line via input streams
|
||||
*
|
||||
|
@ -2280,7 +2275,6 @@ public class Nd4j {
|
|||
* @param split the split separator
|
||||
* @param charset the charset
|
||||
* @return the deserialized array.
|
||||
* @throws IOException
|
||||
*/
|
||||
public static INDArray readNumpy(@NonNull DataType dataType, @NonNull InputStream filePath, @NonNull String split, @NonNull Charset charset) throws IOException {
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(filePath, charset));
|
||||
|
@ -2449,7 +2443,7 @@ public class Nd4j {
|
|||
}
|
||||
//parse data
|
||||
if (lineNum > 5) {
|
||||
String[] entries = line.replace("\\],", "").replaceAll("\\]", "").replaceAll("\\[", "").split(sep);
|
||||
String[] entries = line.replace("\\],", "").replaceAll("]", "").replaceAll("\\[", "").split(sep);
|
||||
if (rank == 0) {
|
||||
try {
|
||||
newArr.addi((format.parse(entries[0])).doubleValue());
|
||||
|
@ -2496,7 +2490,6 @@ public class Nd4j {
|
|||
* @return NDArray
|
||||
*/
|
||||
public static INDArray readTxt(String filePath) {
|
||||
String sep = ",";
|
||||
File file = new File(filePath);
|
||||
InputStream is = null;
|
||||
try {
|
||||
|
@ -2505,13 +2498,7 @@ public class Nd4j {
|
|||
} catch (FileNotFoundException e) {
|
||||
throw new RuntimeException(e);
|
||||
} finally {
|
||||
if (is != null) {
|
||||
try {
|
||||
is.close();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
IOUtils.closeQuietly(is);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2552,9 +2539,9 @@ public class Nd4j {
|
|||
*/
|
||||
public static INDArray createArrayFromShapeBuffer(DataBuffer data, Pair<DataBuffer, long[]> shapeInfo) {
|
||||
int rank = Shape.rank(shapeInfo.getFirst());
|
||||
long offset = Shape.offset(shapeInfo.getFirst());
|
||||
// removed offset parameter that called a deprecated method which always returns 0.
|
||||
INDArray result = Nd4j.create(data, toIntArray(rank, Shape.shapeOf(shapeInfo.getFirst())),
|
||||
toIntArray(rank, Shape.stride(shapeInfo.getFirst())), offset, Shape.order(shapeInfo.getFirst()));
|
||||
toIntArray(rank, Shape.stride(shapeInfo.getFirst())), 0, Shape.order(shapeInfo.getFirst()));
|
||||
if (data instanceof CompressedDataBuffer)
|
||||
result.markAsCompressed(true);
|
||||
|
||||
|
@ -2566,7 +2553,6 @@ public class Nd4j {
|
|||
*
|
||||
* @param dis the data input stream to read from
|
||||
* @return the ndarray
|
||||
* @throws IOException
|
||||
*/
|
||||
public static INDArray read(DataInputStream dis) throws IOException {
|
||||
val headerShape = BaseDataBuffer.readHeader(dis);
|
||||
|
@ -2597,7 +2583,6 @@ public class Nd4j {
|
|||
*
|
||||
* @param arr the array to write
|
||||
* @param dataOutputStream the data output stream to write to
|
||||
* @throws IOException
|
||||
*/
|
||||
public static void write(INDArray arr, DataOutputStream dataOutputStream) throws IOException {
|
||||
//BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here
|
||||
|
@ -2614,7 +2599,6 @@ public class Nd4j {
|
|||
* Save an ndarray to the given file
|
||||
* @param arr the array to save
|
||||
* @param saveTo the file to save to
|
||||
* @throws IOException
|
||||
*/
|
||||
public static void saveBinary(INDArray arr, File saveTo) throws IOException {
|
||||
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(saveTo));
|
||||
|
@ -2629,7 +2613,6 @@ public class Nd4j {
|
|||
* Read a binary ndarray from the given file
|
||||
* @param read the nd array to read
|
||||
* @return the loaded ndarray
|
||||
* @throws IOException
|
||||
*/
|
||||
public static INDArray readBinary(File read) throws IOException {
|
||||
BufferedInputStream bis = new BufferedInputStream(new FileInputStream(read));
|
||||
|
@ -2645,21 +2628,9 @@ public class Nd4j {
|
|||
* @param arr the array to clear
|
||||
*/
|
||||
public static void clearNans(INDArray arr) {
|
||||
//BooleanIndexing.applyWhere(arr, Conditions.isNan(), new Value(Nd4j.EPS_THRESHOLD));
|
||||
getExecutioner().exec(new ReplaceNans(arr, Nd4j.EPS_THRESHOLD));
|
||||
}
|
||||
|
||||
/**
|
||||
* Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc
|
||||
*
|
||||
* @param reverse the matrix to reverse
|
||||
* @return the reversed matrix
|
||||
*/
|
||||
public static INDArray rot(INDArray reverse) {
|
||||
INDArray ret = INSTANCE.rot(reverse);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc
|
||||
*
|
||||
|
@ -2667,8 +2638,6 @@ public class Nd4j {
|
|||
* @return the reversed matrix
|
||||
*/
|
||||
public static INDArray reverse(INDArray reverse) {
|
||||
//INDArray ret = INSTANCE.reverse(reverse);
|
||||
//logCreationIfNecessary(ret);
|
||||
return Nd4j.getExecutioner().exec(new OldReverse(reverse));
|
||||
}
|
||||
|
||||
|
@ -2682,8 +2651,7 @@ public class Nd4j {
|
|||
* @return the 1D range vector
|
||||
*/
|
||||
public static INDArray arange(double begin, double end, double step) {
|
||||
INDArray ret = INSTANCE.arange(begin, end, step);
|
||||
return ret;
|
||||
return INSTANCE.arange(begin, end, step);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -2693,8 +2661,7 @@ public class Nd4j {
|
|||
* See {@link #arange(double, double, double)} with step size 1.
|
||||
*/
|
||||
public static INDArray arange(double begin, double end) {
|
||||
INDArray ret = INSTANCE.arange(begin, end, 1);
|
||||
return ret;
|
||||
return INSTANCE.arange(begin, end, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -2717,27 +2684,6 @@ public class Nd4j {
|
|||
INSTANCE.copy(a, b);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new matrix where the values of the given vector are the diagonal values of
|
||||
* the matrix if a vector is passed in, if a matrix is returns the kth diagonal
|
||||
* in the matrix
|
||||
*
|
||||
* @param x the diagonal values
|
||||
* @param k the kth diagonal to get
|
||||
* @return new matrix
|
||||
*/
|
||||
public static INDArray diag(INDArray x, int k) {
|
||||
INDArray ret;
|
||||
if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) {
|
||||
ret = Nd4j.create(new long[]{x.length(), x.length()});
|
||||
Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret}));
|
||||
} else {
|
||||
ret = Nd4j.createUninitialized(new long[]{Math.min(x.size(0), x.size(1))});
|
||||
Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new matrix where the values of the given vector are the diagonal values of
|
||||
* the matrix if a vector is passed in, if a matrix is returns the kth diagonal
|
||||
|
@ -2747,7 +2693,15 @@ public class Nd4j {
|
|||
* @return new matrix
|
||||
*/
|
||||
public static INDArray diag(INDArray x) {
|
||||
return diag(x, 0);
|
||||
INDArray ret;
|
||||
if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) {
|
||||
ret = Nd4j.create(x.dataType(), x.length(), x.length());
|
||||
Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret}));
|
||||
} else {
|
||||
ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1)));
|
||||
Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
Loading…
Reference in New Issue