diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java index bd8726fed..3be8401d7 100755 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java @@ -22,10 +22,12 @@ import org.deeplearning4j.base.MnistFetcher; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.common.resources.ResourceType; import org.deeplearning4j.datasets.mnist.MnistManager; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.util.MathUtils; import java.io.File; @@ -167,14 +169,19 @@ public class MnistDataFetcher extends BaseDataFetcher { this(true); } + private float[][] featureData = null; + @Override public void fetch(int numExamples) { if (!hasMore()) { throw new IllegalStateException("Unable to get more; there are no more images"); } - float[][] featureData = new float[numExamples][0]; - float[][] labelData = new float[numExamples][0]; + INDArray labels = Nd4j.zeros(DataType.FLOAT, numExamples, numOutcomes); + + if(featureData == null || featureData.length < numExamples){ + featureData = new float[numExamples][28*28]; + } int actualExamples = 0; byte[] working = null; @@ -202,33 +209,33 @@ public class MnistDataFetcher extends BaseDataFetcher { label--; } - float[] featureVec = new float[img.length]; - featureData[actualExamples] = featureVec; - labelData[actualExamples] = new float[numOutcomes]; - labelData[actualExamples][label] = 1.0f; + labels.put(actualExamples, label, 1.0f); - for (int j = 0; j < img.length; j++) { - float v = ((int) img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned - if (binarize) { - if (v > 30.0f) - featureVec[j] = 1.0f; - else - featureVec[j] = 0.0f; - } else { - featureVec[j] = v / 255.0f; - } + for(int j = 0 ; j < img.length ; j++) { + featureData[actualExamples][j] = ((int) img[j]) & 0xFF; } actualExamples++; } - if (actualExamples < numExamples) { - featureData = Arrays.copyOfRange(featureData, 0, actualExamples); - labelData = Arrays.copyOfRange(labelData, 0, actualExamples); + INDArray features; + + if(featureData.length == actualExamples){ + features = Nd4j.create(featureData); + } else { + features = Nd4j.create(Arrays.copyOfRange(featureData, 0, actualExamples)); + } + + if (actualExamples < numExamples) { + labels = labels.get(NDArrayIndex.interval(0, actualExamples), NDArrayIndex.all()); + } + + if(binarize){ + features = features.gt(30.0).castTo(DataType.FLOAT); + } else { + features.divi(255.0); } - INDArray features = Nd4j.create(featureData); - INDArray labels = Nd4j.create(labelData); curr = new DataSet(features, labels); }