Fix MNIST Fetcher to not re-allocate each batch (#200)

* don't allocate so many float arrays, use INDArrays instead

Signed-off-by: Ryan Nett <rnett@skymind.io>

* re-add pre-processing, better names

Signed-off-by: Ryan Nett <rnett@skymind.io>

* use float[][] pool to avoid extra ndarray creation

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-08-29 21:28:34 -07:00 committed by GitHub
parent 5cb6bebe4d
commit 378669cc10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 28 additions and 21 deletions

View File

@ -22,10 +22,12 @@ import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.common.resources.ResourceType; import org.deeplearning4j.common.resources.ResourceType;
import org.deeplearning4j.datasets.mnist.MnistManager; import org.deeplearning4j.datasets.mnist.MnistManager;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher; import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.MathUtils; import org.nd4j.linalg.util.MathUtils;
import java.io.File; import java.io.File;
@ -167,14 +169,19 @@ public class MnistDataFetcher extends BaseDataFetcher {
this(true); this(true);
} }
private float[][] featureData = null;
@Override @Override
public void fetch(int numExamples) { public void fetch(int numExamples) {
if (!hasMore()) { if (!hasMore()) {
throw new IllegalStateException("Unable to get more; there are no more images"); throw new IllegalStateException("Unable to get more; there are no more images");
} }
float[][] featureData = new float[numExamples][0]; INDArray labels = Nd4j.zeros(DataType.FLOAT, numExamples, numOutcomes);
float[][] labelData = new float[numExamples][0];
if(featureData == null || featureData.length < numExamples){
featureData = new float[numExamples][28*28];
}
int actualExamples = 0; int actualExamples = 0;
byte[] working = null; byte[] working = null;
@ -202,33 +209,33 @@ public class MnistDataFetcher extends BaseDataFetcher {
label--; label--;
} }
float[] featureVec = new float[img.length]; labels.put(actualExamples, label, 1.0f);
featureData[actualExamples] = featureVec;
labelData[actualExamples] = new float[numOutcomes];
labelData[actualExamples][label] = 1.0f;
for (int j = 0; j < img.length; j++) { for(int j = 0 ; j < img.length ; j++) {
float v = ((int) img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned featureData[actualExamples][j] = ((int) img[j]) & 0xFF;
if (binarize) {
if (v > 30.0f)
featureVec[j] = 1.0f;
else
featureVec[j] = 0.0f;
} else {
featureVec[j] = v / 255.0f;
}
} }
actualExamples++; actualExamples++;
} }
if (actualExamples < numExamples) { INDArray features;
featureData = Arrays.copyOfRange(featureData, 0, actualExamples);
labelData = Arrays.copyOfRange(labelData, 0, actualExamples); if(featureData.length == actualExamples){
features = Nd4j.create(featureData);
} else {
features = Nd4j.create(Arrays.copyOfRange(featureData, 0, actualExamples));
}
if (actualExamples < numExamples) {
labels = labels.get(NDArrayIndex.interval(0, actualExamples), NDArrayIndex.all());
}
if(binarize){
features = features.gt(30.0).castTo(DataType.FLOAT);
} else {
features.divi(255.0);
} }
INDArray features = Nd4j.create(featureData);
INDArray labels = Nd4j.create(labelData);
curr = new DataSet(features, labels); curr = new DataSet(features, labels);
} }