From 092e6b9891469334e6faa5ef3896856626d36359 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 4 May 2020 14:25:19 +1000 Subject: [PATCH] ImageRecordReader - NHWC support (#431) * #8201 ImageRecordReader - NHWC support Signed-off-by: Alex Black * Check values - NCHW vs. NHWC Signed-off-by: Alex Black --- .../writable/batch/NDArrayRecordBatch.java | 1 + .../recordreader/BaseImageRecordReader.java | 29 +++++-- .../image/recordreader/ImageRecordReader.java | 45 +++++++--- .../recordreader/TestImageRecordReader.java | 84 ++++++++++++++++++- 4 files changed, 140 insertions(+), 19 deletions(-) diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java index 0a5ddddb8..e9b78e390 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java @@ -52,6 +52,7 @@ public class NDArrayRecordBatch extends AbstractWritableRecordBatch { public NDArrayRecordBatch(@NonNull List arrays){ Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty"); this.arrays = arrays; + this.size = arrays.get(0).size(0); //Check that dimension 0 matches: if(arrays.size() > 1){ diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java index fb780ea74..d5400ee8e 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java @@ -77,6 +77,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { protected int patternPosition = 0; @Getter @Setter 
protected boolean logLabelCountOnInit = true; + @Getter @Setter + protected boolean nchw_channels_first = true; public final static String HEIGHT = NAME_SPACE + ".height"; public final static String WIDTH = NAME_SPACE + ".width"; @@ -101,6 +103,11 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator, PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) { + this(height, width, channels, true, labelGenerator, labelMultiGenerator, imageTransform); + } + + protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator, + PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) { this.height = height; this.width = width; this.channels = channels; @@ -108,6 +115,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { this.labelMultiGenerator = labelMultiGenerator; this.imageTransform = imageTransform; this.appendLabel = (labelGenerator != null || labelMultiGenerator != null); + this.nchw_channels_first = nchw_channels_first; } protected boolean containsFormat(String format) { @@ -237,9 +245,13 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { return next(); try { invokeListeners(image); - INDArray row = imageLoader.asMatrix(image); - Nd4j.getAffinityManager().ensureLocation(row, AffinityManager.Location.DEVICE); - ret = RecordConverter.toRecord(row); + INDArray array = imageLoader.asMatrix(image); + if(!nchw_channels_first){ + array = array.permute(0,2,3,1); //NCHW to NHWC + } + + Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE); + ret = RecordConverter.toRecord(array); if (appendLabel || writeLabel){ if(labelMultiGenerator != null){ ret.addAll(labelMultiGenerator.getLabels(image.getPath())); @@ -286,7 +298,7 @@ public abstract class BaseImageRecordReader 
extends BaseRecordReader { @Override public List> next(int num) { - Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got " + num); + Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num); if (imageLoader == null) { imageLoader = new NativeImageLoader(height, width, channels, imageTransform); @@ -337,6 +349,9 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { throw new RuntimeException(e); } } + if(!nchw_channels_first){ + features = features.permute(0,2,3,1); //NCHW to NHWC + } Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE); @@ -483,8 +498,10 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { if (imageLoader == null) { imageLoader = new NativeImageLoader(height, width, channels, imageTransform); } - INDArray row = imageLoader.asMatrix(dataInputStream); - List ret = RecordConverter.toRecord(row); + INDArray array = imageLoader.asMatrix(dataInputStream); + if(!nchw_channels_first) + array = array.permute(0,2,3,1); + List ret = RecordConverter.toRecord(array); if (appendLabel) ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath())))); return ret; diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java index be7a6d8d9..f8e292c26 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java @@ -34,47 +34,70 @@ import org.datavec.image.transform.ImageTransform; public class ImageRecordReader extends BaseImageRecordReader { - /** Loads images with height = 28, width = 28, and channels = 1, appending no labels. 
*/ + /** Loads images with height = 28, width = 28, and channels = 1, appending no labels. + * Output format is NCHW (channels first) - [numExamples, 1, 28, 28]*/ public ImageRecordReader() { super(); } - /** Loads images with given height, width, and channels, appending labels returned by the generator. */ + /** Loads images with given height, width, and channels, appending labels returned by the generator. + * Output format is NCHW (channels first) - [numExamples, channels, height, width] + */ public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) { super(height, width, channels, labelGenerator); } - /** Loads images with given height, width, and channels, appending labels returned by the generator. */ + /** Loads images with given height, width, and channels, appending labels returned by the generator. + * Output format is NCHW (channels first) - [numExamples, channels, height, width] + */ public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) { super(height, width, channels, labelGenerator); } - /** Loads images with given height, width, and channels, appending no labels. */ + /** Loads images with given height, width, and channels, appending no labels - in NCHW (channels first) format */ public ImageRecordReader(long height, long width, long channels) { super(height, width, channels, (PathLabelGenerator) null); } - /** Loads images with given height, width, and channels, appending labels returned by the generator. */ + /** Loads images with given height, width, and channels, appending no labels - in specified format
+ * If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
+ * If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
+ */ + public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first) { + super(height, width, channels, nchw_channels_first, null, null, null); + } + + /** Loads images with given height, width, and channels, appending labels returned by the generator. + * Output format is NCHW (channels first) - [numExamples, channels, height, width] */ public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator, ImageTransform imageTransform) { super(height, width, channels, labelGenerator, imageTransform); } - /** Loads images with given height, width, and channels, appending no labels. */ + /** Loads images with given height, width, and channels, appending labels returned by the generator.
+ * If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
+ * If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
+ */ + public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator, + ImageTransform imageTransform) { + super(height, width, channels, nchw_channels_first, labelGenerator, null, imageTransform); + } + + /** Loads images with given height, width, and channels, appending no labels. + * Output format is NCHW (channels first) - [numExamples, channels, height, width]*/ public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) { super(height, width, channels, null, imageTransform); } - /** Loads images with given height, width, and channels, appending labels returned by the generator. */ + /** Loads images with given height, width, and channels, appending labels returned by the generator + * Output format is NCHW (channels first) - [numExamples, channels, height, width]*/ public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) { super(height, width, 1, labelGenerator); } - /** Loads images with given height, width, and channels = 1, appending no labels. */ + /** Loads images with given height, width, and channels = 1, appending no labels. 
+ * Output format is NCHW (channels first) - [numExamples, channels, height, width]*/ public ImageRecordReader(long height, long width) { super(height, width, 1, null, null); } - - - } diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java index 80cb9b0af..fdcbe959a 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java @@ -35,13 +35,13 @@ import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.common.resources.Resources; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; -import java.io.File; -import java.io.IOException; +import java.io.*; import java.net.URI; import java.util.ArrayList; import java.util.Arrays; @@ -467,5 +467,85 @@ public class TestImageRecordReader { return count; } } + + + + @Test + public void testNCHW_NCHW() throws Exception { + //Idea: NCHW and NHWC readers must return the same values, differing only in dimension order + File f0 = testDir.newFolder(); + Resources.copyDirectory("datavec-data-image/testimages/", f0); + + FileSplit fs = new FileSplit(f0, new Random(12345)); + assertEquals(6, fs.locations().length); + + ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true); + nchw.initialize(fs); + + ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false); + nhwc.initialize(fs); + + while(nchw.hasNext()){ + assertTrue(nhwc.hasNext()); + + List l_nchw = nchw.next(); + List l_nhwc = nhwc.next(); + + INDArray a_nchw = ((NDArrayWritable)l_nchw.get(0)).get(); +
INDArray a_nhwc = ((NDArrayWritable)l_nhwc.get(0)).get(); + + assertArrayEquals(new long[]{1, 3, 32, 32}, a_nchw.shape()); + assertArrayEquals(new long[]{1, 32, 32, 3}, a_nhwc.shape()); + + INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW + assertEquals(a_nchw, permuted); + } + + + //Test batch: + nchw.reset(); + nhwc.reset(); + + int batchCount = 0; + while(nchw.hasNext()){ + assertTrue(nhwc.hasNext()); + batchCount++; + + List> l_nchw = nchw.next(3); + List> l_nhwc = nhwc.next(3); + assertEquals(3, l_nchw.size()); + assertEquals(3, l_nhwc.size()); + + NDArrayRecordBatch b_nchw = (NDArrayRecordBatch)l_nchw; + NDArrayRecordBatch b_nhwc = (NDArrayRecordBatch)l_nhwc; + + INDArray a_nchw = b_nchw.getArrays().get(0); + INDArray a_nhwc = b_nhwc.getArrays().get(0); + + assertArrayEquals(new long[]{3, 3, 32, 32}, a_nchw.shape()); + assertArrayEquals(new long[]{3, 32, 32, 3}, a_nhwc.shape()); + + INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW + assertEquals(a_nchw, permuted); + } + assertEquals(2, batchCount); + + + //Test record(URI, DataInputStream) + + URI u = fs.locations()[0]; + + try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) { + List l = nchw.record(u, dis); + INDArray arr = ((NDArrayWritable)l.get(0)).get(); + assertArrayEquals(new long[]{1, 3, 32, 32}, arr.shape()); + } + + try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) { + List l = nhwc.record(u, dis); + INDArray arr = ((NDArrayWritable)l.get(0)).get(); + assertArrayEquals(new long[]{1, 32, 32, 3}, arr.shape()); + } + } }