Nd4j refactoring (#109)

* refactoring.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* review feedback

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-13 10:30:45 +09:00 committed by Alex Black
parent f49c4ea9d0
commit c29c011d1a
1 changed files with 18 additions and 64 deletions

View File

@ -2222,7 +2222,6 @@ public class Nd4j {
* Write an ndarray to a writer
* @param writer the writer to write to
* @param write the ndarray to write
* @throws IOException
*/
public static void write(OutputStream writer, INDArray write) throws IOException {
DataOutputStream stream = new DataOutputStream(writer);
@ -2230,14 +2229,12 @@ public class Nd4j {
stream.close();
}
/**
* Convert an ndarray to a byte array
* @param arr the array to convert
* @return the converted byte array
* @throws IOException
*/
public static byte[] toByteArray(INDArray arr) throws IOException {
public static byte[] toByteArray(@NonNull INDArray arr) throws IOException {
if (arr.length() * arr.data().getElementSize() > Integer.MAX_VALUE)
throw new ND4JIllegalStateException("");
@ -2252,15 +2249,13 @@ public class Nd4j {
* Read an ndarray from a byte array
* @param arr the array to read from
* @return the deserialized ndarray
* @throws IOException
*/
public static INDArray fromByteArray(byte[] arr) throws IOException {
public static INDArray fromByteArray(@NonNull byte[] arr) throws IOException {
ByteArrayInputStream bis = new ByteArrayInputStream(arr);
INDArray ret = read(bis);
return ret;
}
/**
* Read line via input streams
*
@ -2280,7 +2275,6 @@ public class Nd4j {
* @param split the split separator
* @param charset the charset
* @return the deserialized array.
* @throws IOException
*/
public static INDArray readNumpy(@NonNull DataType dataType, @NonNull InputStream filePath, @NonNull String split, @NonNull Charset charset) throws IOException {
BufferedReader reader = new BufferedReader(new InputStreamReader(filePath, charset));
@ -2449,7 +2443,7 @@ public class Nd4j {
}
//parse data
if (lineNum > 5) {
String[] entries = line.replace("\\],", "").replaceAll("\\]", "").replaceAll("\\[", "").split(sep);
String[] entries = line.replace("\\],", "").replaceAll("]", "").replaceAll("\\[", "").split(sep);
if (rank == 0) {
try {
newArr.addi((format.parse(entries[0])).doubleValue());
@ -2496,7 +2490,6 @@ public class Nd4j {
* @return NDArray
*/
public static INDArray readTxt(String filePath) {
String sep = ",";
File file = new File(filePath);
InputStream is = null;
try {
@ -2505,13 +2498,7 @@ public class Nd4j {
} catch (FileNotFoundException e) {
throw new RuntimeException(e);
} finally {
if (is != null) {
try {
is.close();
} catch (IOException e) {
e.printStackTrace();
}
}
IOUtils.closeQuietly(is);
}
}
@ -2552,9 +2539,9 @@ public class Nd4j {
*/
public static INDArray createArrayFromShapeBuffer(DataBuffer data, Pair<DataBuffer, long[]> shapeInfo) {
int rank = Shape.rank(shapeInfo.getFirst());
long offset = Shape.offset(shapeInfo.getFirst());
// removed offset parameter that called a deprecated method which always returns 0.
INDArray result = Nd4j.create(data, toIntArray(rank, Shape.shapeOf(shapeInfo.getFirst())),
toIntArray(rank, Shape.stride(shapeInfo.getFirst())), offset, Shape.order(shapeInfo.getFirst()));
toIntArray(rank, Shape.stride(shapeInfo.getFirst())), 0, Shape.order(shapeInfo.getFirst()));
if (data instanceof CompressedDataBuffer)
result.markAsCompressed(true);
@ -2566,7 +2553,6 @@ public class Nd4j {
*
* @param dis the data input stream to read from
* @return the ndarray
* @throws IOException
*/
public static INDArray read(DataInputStream dis) throws IOException {
val headerShape = BaseDataBuffer.readHeader(dis);
@ -2597,7 +2583,6 @@ public class Nd4j {
*
* @param arr the array to write
* @param dataOutputStream the data output stream to write to
* @throws IOException
*/
public static void write(INDArray arr, DataOutputStream dataOutputStream) throws IOException {
//BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here
@ -2614,7 +2599,6 @@ public class Nd4j {
* Save an ndarray to the given file
* @param arr the array to save
* @param saveTo the file to save to
* @throws IOException
*/
public static void saveBinary(INDArray arr, File saveTo) throws IOException {
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(saveTo));
@ -2629,7 +2613,6 @@ public class Nd4j {
* Read a binary ndarray from the given file
* @param read the nd array to read
* @return the loaded ndarray
* @throws IOException
*/
public static INDArray readBinary(File read) throws IOException {
BufferedInputStream bis = new BufferedInputStream(new FileInputStream(read));
@ -2645,21 +2628,9 @@ public class Nd4j {
* @param arr the array to clear
*/
public static void clearNans(INDArray arr) {
//BooleanIndexing.applyWhere(arr, Conditions.isNan(), new Value(Nd4j.EPS_THRESHOLD));
getExecutioner().exec(new ReplaceNans(arr, Nd4j.EPS_THRESHOLD));
}
/**
* Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc
*
* @param reverse the matrix to reverse
* @return the reversed matrix
*/
public static INDArray rot(INDArray reverse) {
INDArray ret = INSTANCE.rot(reverse);
return ret;
}
/**
* Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc
*
@ -2667,8 +2638,6 @@ public class Nd4j {
* @return the reversed matrix
*/
public static INDArray reverse(INDArray reverse) {
//INDArray ret = INSTANCE.reverse(reverse);
//logCreationIfNecessary(ret);
return Nd4j.getExecutioner().exec(new OldReverse(reverse));
}
@ -2682,8 +2651,7 @@ public class Nd4j {
* @return the 1D range vector
*/
public static INDArray arange(double begin, double end, double step) {
INDArray ret = INSTANCE.arange(begin, end, step);
return ret;
return INSTANCE.arange(begin, end, step);
}
/**
@ -2693,8 +2661,7 @@ public class Nd4j {
* See {@link #arange(double, double, double)} with step size 1.
*/
public static INDArray arange(double begin, double end) {
INDArray ret = INSTANCE.arange(begin, end, 1);
return ret;
return INSTANCE.arange(begin, end, 1);
}
/**
@ -2717,27 +2684,6 @@ public class Nd4j {
INSTANCE.copy(a, b);
}
/**
* Creates a new matrix where the values of the given vector are the diagonal values of
* the matrix if a vector is passed in, if a matrix is returns the kth diagonal
* in the matrix
*
* @param x the diagonal values
* @param k the kth diagonal to get
* @return new matrix
*/
public static INDArray diag(INDArray x, int k) {
INDArray ret;
if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) {
ret = Nd4j.create(new long[]{x.length(), x.length()});
Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret}));
} else {
ret = Nd4j.createUninitialized(new long[]{Math.min(x.size(0), x.size(1))});
Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret));
}
return ret;
}
/**
* Creates a new matrix where the values of the given vector are the diagonal values of
* the matrix if a vector is passed in, if a matrix is returns the kth diagonal
@ -2747,7 +2693,15 @@ public class Nd4j {
* @return new matrix
*/
public static INDArray diag(INDArray x) {
return diag(x, 0);
INDArray ret;
if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) {
ret = Nd4j.create(x.dataType(), x.length(), x.length());
Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret}));
} else {
ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1)));
Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret));
}
return ret;
}
/**