Fix MNIST Fetcher to not re-allocate each batch (#200)
* don't allocate so many float arrays, use INDArrays instead
  Signed-off-by: Ryan Nett <rnett@skymind.io>
* re-add pre-processing, better names
  Signed-off-by: Ryan Nett <rnett@skymind.io>
* use float[][] pool to avoid extra ndarray creation
  Signed-off-by: Ryan Nett <rnett@skymind.io>
master
parent
5cb6bebe4d
commit
378669cc10
|
@ -22,10 +22,12 @@ import org.deeplearning4j.base.MnistFetcher;
|
|||
import org.deeplearning4j.common.resources.DL4JResources;
|
||||
import org.deeplearning4j.common.resources.ResourceType;
|
||||
import org.deeplearning4j.datasets.mnist.MnistManager;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.util.MathUtils;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -167,14 +169,19 @@ public class MnistDataFetcher extends BaseDataFetcher {
|
|||
this(true);
|
||||
}
|
||||
|
||||
private float[][] featureData = null;
|
||||
|
||||
@Override
|
||||
public void fetch(int numExamples) {
|
||||
if (!hasMore()) {
|
||||
throw new IllegalStateException("Unable to get more; there are no more images");
|
||||
}
|
||||
|
||||
float[][] featureData = new float[numExamples][0];
|
||||
float[][] labelData = new float[numExamples][0];
|
||||
INDArray labels = Nd4j.zeros(DataType.FLOAT, numExamples, numOutcomes);
|
||||
|
||||
if(featureData == null || featureData.length < numExamples){
|
||||
featureData = new float[numExamples][28*28];
|
||||
}
|
||||
|
||||
int actualExamples = 0;
|
||||
byte[] working = null;
|
||||
|
@ -202,33 +209,33 @@ public class MnistDataFetcher extends BaseDataFetcher {
|
|||
label--;
|
||||
}
|
||||
|
||||
float[] featureVec = new float[img.length];
|
||||
featureData[actualExamples] = featureVec;
|
||||
labelData[actualExamples] = new float[numOutcomes];
|
||||
labelData[actualExamples][label] = 1.0f;
|
||||
labels.put(actualExamples, label, 1.0f);
|
||||
|
||||
for(int j = 0 ; j < img.length ; j++) {
|
||||
float v = ((int) img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned
|
||||
if (binarize) {
|
||||
if (v > 30.0f)
|
||||
featureVec[j] = 1.0f;
|
||||
else
|
||||
featureVec[j] = 0.0f;
|
||||
} else {
|
||||
featureVec[j] = v / 255.0f;
|
||||
}
|
||||
featureData[actualExamples][j] = ((int) img[j]) & 0xFF;
|
||||
}
|
||||
|
||||
actualExamples++;
|
||||
}
|
||||
|
||||
if (actualExamples < numExamples) {
|
||||
featureData = Arrays.copyOfRange(featureData, 0, actualExamples);
|
||||
labelData = Arrays.copyOfRange(labelData, 0, actualExamples);
|
||||
INDArray features;
|
||||
|
||||
if(featureData.length == actualExamples){
|
||||
features = Nd4j.create(featureData);
|
||||
} else {
|
||||
features = Nd4j.create(Arrays.copyOfRange(featureData, 0, actualExamples));
|
||||
}
|
||||
|
||||
if (actualExamples < numExamples) {
|
||||
labels = labels.get(NDArrayIndex.interval(0, actualExamples), NDArrayIndex.all());
|
||||
}
|
||||
|
||||
if(binarize){
|
||||
features = features.gt(30.0).castTo(DataType.FLOAT);
|
||||
} else {
|
||||
features.divi(255.0);
|
||||
}
|
||||
|
||||
INDArray features = Nd4j.create(featureData);
|
||||
INDArray labels = Nd4j.create(labelData);
|
||||
curr = new DataSet(features, labels);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue