Fix MNIST Fetcher to not re-allocate each batch (#200)
* don't allocate so many float arrays, use INDArrays instead
  Signed-off-by: Ryan Nett <rnett@skymind.io>
* re-add pre-processing, better names
  Signed-off-by: Ryan Nett <rnett@skymind.io>
* use float[][] pool to avoid extra ndarray creation
  Signed-off-by: Ryan Nett <rnett@skymind.io>
master
parent
5cb6bebe4d
commit
378669cc10
|
@@ -22,10 +22,12 @@ import org.deeplearning4j.base.MnistFetcher;
|
||||||
import org.deeplearning4j.common.resources.DL4JResources;
|
import org.deeplearning4j.common.resources.DL4JResources;
|
||||||
import org.deeplearning4j.common.resources.ResourceType;
|
import org.deeplearning4j.common.resources.ResourceType;
|
||||||
import org.deeplearning4j.datasets.mnist.MnistManager;
|
import org.deeplearning4j.datasets.mnist.MnistManager;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
|
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.util.MathUtils;
|
import org.nd4j.linalg.util.MathUtils;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@@ -167,14 +169,19 @@ public class MnistDataFetcher extends BaseDataFetcher {
|
||||||
this(true);
|
this(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private float[][] featureData = null;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void fetch(int numExamples) {
|
public void fetch(int numExamples) {
|
||||||
if (!hasMore()) {
|
if (!hasMore()) {
|
||||||
throw new IllegalStateException("Unable to get more; there are no more images");
|
throw new IllegalStateException("Unable to get more; there are no more images");
|
||||||
}
|
}
|
||||||
|
|
||||||
float[][] featureData = new float[numExamples][0];
|
INDArray labels = Nd4j.zeros(DataType.FLOAT, numExamples, numOutcomes);
|
||||||
float[][] labelData = new float[numExamples][0];
|
|
||||||
|
if(featureData == null || featureData.length < numExamples){
|
||||||
|
featureData = new float[numExamples][28*28];
|
||||||
|
}
|
||||||
|
|
||||||
int actualExamples = 0;
|
int actualExamples = 0;
|
||||||
byte[] working = null;
|
byte[] working = null;
|
||||||
|
@@ -202,33 +209,33 @@ public class MnistDataFetcher extends BaseDataFetcher {
|
||||||
label--;
|
label--;
|
||||||
}
|
}
|
||||||
|
|
||||||
float[] featureVec = new float[img.length];
|
labels.put(actualExamples, label, 1.0f);
|
||||||
featureData[actualExamples] = featureVec;
|
|
||||||
labelData[actualExamples] = new float[numOutcomes];
|
|
||||||
labelData[actualExamples][label] = 1.0f;
|
|
||||||
|
|
||||||
for (int j = 0; j < img.length; j++) {
|
for(int j = 0 ; j < img.length ; j++) {
|
||||||
float v = ((int) img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned
|
featureData[actualExamples][j] = ((int) img[j]) & 0xFF;
|
||||||
if (binarize) {
|
|
||||||
if (v > 30.0f)
|
|
||||||
featureVec[j] = 1.0f;
|
|
||||||
else
|
|
||||||
featureVec[j] = 0.0f;
|
|
||||||
} else {
|
|
||||||
featureVec[j] = v / 255.0f;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
actualExamples++;
|
actualExamples++;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (actualExamples < numExamples) {
|
INDArray features;
|
||||||
featureData = Arrays.copyOfRange(featureData, 0, actualExamples);
|
|
||||||
labelData = Arrays.copyOfRange(labelData, 0, actualExamples);
|
if(featureData.length == actualExamples){
|
||||||
|
features = Nd4j.create(featureData);
|
||||||
|
} else {
|
||||||
|
features = Nd4j.create(Arrays.copyOfRange(featureData, 0, actualExamples));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (actualExamples < numExamples) {
|
||||||
|
labels = labels.get(NDArrayIndex.interval(0, actualExamples), NDArrayIndex.all());
|
||||||
|
}
|
||||||
|
|
||||||
|
if(binarize){
|
||||||
|
features = features.gt(30.0).castTo(DataType.FLOAT);
|
||||||
|
} else {
|
||||||
|
features.divi(255.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray features = Nd4j.create(featureData);
|
|
||||||
INDArray labels = Nd4j.create(labelData);
|
|
||||||
curr = new DataSet(features, labels);
|
curr = new DataSet(features, labels);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue