Small optimization to Nd4j.readNumpy (#183)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-08-27 23:27:41 +10:00 committed by GitHub
parent 7f0c660d8b
commit 9d325ad070
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 7 deletions

View File

@ -2265,15 +2265,12 @@ public class Nd4j {
Preconditions.checkState(data.length == numColumns,
"Data has inconsistent number of columns: data length %s, numColumns %s", data.length, numColumns);
data2.add(readSplit(data));
}
ret = Nd4j.create(dataType, data2.size(), numColumns);
for (int i = 0; i < data2.size(); i++) {
float[] row = data2.get(i);
INDArray arr = Nd4j.create(row, new long[]{1, row.length}, dataType);
ret.putRow(i, arr);
float[][] fArr = new float[data2.size()][0];
for(int i=0; i<data2.size(); i++ ){
fArr[i] = data2.get(i);
}
ret = Nd4j.createFromArray(fArr).castTo(dataType);
return ret;
}