Fix MNIST Fetcher to not re-allocate each batch (#200)

* don't allocate so many float arrays, use INDArrays instead

Signed-off-by: Ryan Nett <rnett@skymind.io>

* re-add pre-processing, better names

Signed-off-by: Ryan Nett <rnett@skymind.io>

* use float[][] pool to avoid extra ndarray creation

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-08-29 21:28:34 -07:00 committed by GitHub
parent 5cb6bebe4d
commit 378669cc10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 28 additions and 21 deletions

View File

@ -22,10 +22,12 @@ import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.common.resources.ResourceType;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.MathUtils;
import java.io.File;
@ -167,14 +169,19 @@ public class MnistDataFetcher extends BaseDataFetcher {
this(true);
}
private float[][] featureData = null;
@Override
public void fetch(int numExamples) {
if (!hasMore()) {
throw new IllegalStateException("Unable to get more; there are no more images");
}
float[][] featureData = new float[numExamples][0];
float[][] labelData = new float[numExamples][0];
INDArray labels = Nd4j.zeros(DataType.FLOAT, numExamples, numOutcomes);
if(featureData == null || featureData.length < numExamples){
featureData = new float[numExamples][28*28];
}
int actualExamples = 0;
byte[] working = null;
@ -202,33 +209,33 @@ public class MnistDataFetcher extends BaseDataFetcher {
label--;
}
float[] featureVec = new float[img.length];
featureData[actualExamples] = featureVec;
labelData[actualExamples] = new float[numOutcomes];
labelData[actualExamples][label] = 1.0f;
labels.put(actualExamples, label, 1.0f);
for(int j = 0 ; j < img.length ; j++) {
float v = ((int) img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned
if (binarize) {
if (v > 30.0f)
featureVec[j] = 1.0f;
else
featureVec[j] = 0.0f;
} else {
featureVec[j] = v / 255.0f;
}
featureData[actualExamples][j] = ((int) img[j]) & 0xFF;
}
actualExamples++;
}
if (actualExamples < numExamples) {
featureData = Arrays.copyOfRange(featureData, 0, actualExamples);
labelData = Arrays.copyOfRange(labelData, 0, actualExamples);
INDArray features;
if(featureData.length == actualExamples){
features = Nd4j.create(featureData);
} else {
features = Nd4j.create(Arrays.copyOfRange(featureData, 0, actualExamples));
}
if (actualExamples < numExamples) {
labels = labels.get(NDArrayIndex.interval(0, actualExamples), NDArrayIndex.all());
}
if(binarize){
features = features.gt(30.0).castTo(DataType.FLOAT);
} else {
features.divi(255.0);
}
INDArray features = Nd4j.create(featureData);
INDArray labels = Nd4j.create(labelData);
curr = new DataSet(features, labels);
}