ImageRecordReader - NHWC support (#431)

* #8201 ImageRecordReader - NHWC support

Signed-off-by: Alex Black <blacka101@gmail.com>

* Check values - NCHW vs. NHWC

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-04 14:25:19 +10:00 committed by GitHub
parent 4dbdaca967
commit 092e6b9891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 140 additions and 19 deletions

View File

@ -52,6 +52,7 @@ public class NDArrayRecordBatch extends AbstractWritableRecordBatch {
public NDArrayRecordBatch(@NonNull List<INDArray> arrays){ public NDArrayRecordBatch(@NonNull List<INDArray> arrays){
Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty"); Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty");
this.arrays = arrays; this.arrays = arrays;
this.size = arrays.get(0).size(0);
//Check that dimension 0 matches: //Check that dimension 0 matches:
if(arrays.size() > 1){ if(arrays.size() > 1){

View File

@ -77,6 +77,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
protected int patternPosition = 0; protected int patternPosition = 0;
@Getter @Setter @Getter @Setter
protected boolean logLabelCountOnInit = true; protected boolean logLabelCountOnInit = true;
@Getter @Setter
protected boolean nchw_channels_first = true;
public final static String HEIGHT = NAME_SPACE + ".height"; public final static String HEIGHT = NAME_SPACE + ".height";
public final static String WIDTH = NAME_SPACE + ".width"; public final static String WIDTH = NAME_SPACE + ".width";
@ -101,6 +103,11 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator, protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) { PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
this(height, width, channels, true, labelGenerator, labelMultiGenerator, imageTransform);
}
protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
this.height = height; this.height = height;
this.width = width; this.width = width;
this.channels = channels; this.channels = channels;
@ -108,6 +115,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
this.labelMultiGenerator = labelMultiGenerator; this.labelMultiGenerator = labelMultiGenerator;
this.imageTransform = imageTransform; this.imageTransform = imageTransform;
this.appendLabel = (labelGenerator != null || labelMultiGenerator != null); this.appendLabel = (labelGenerator != null || labelMultiGenerator != null);
this.nchw_channels_first = nchw_channels_first;
} }
protected boolean containsFormat(String format) { protected boolean containsFormat(String format) {
@ -237,9 +245,13 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
return next(); return next();
try { try {
invokeListeners(image); invokeListeners(image);
INDArray row = imageLoader.asMatrix(image); INDArray array = imageLoader.asMatrix(image);
Nd4j.getAffinityManager().ensureLocation(row, AffinityManager.Location.DEVICE); if(!nchw_channels_first){
ret = RecordConverter.toRecord(row); array = array.permute(0,2,3,1); //NCHW to NHWC
}
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
ret = RecordConverter.toRecord(array);
if (appendLabel || writeLabel){ if (appendLabel || writeLabel){
if(labelMultiGenerator != null){ if(labelMultiGenerator != null){
ret.addAll(labelMultiGenerator.getLabels(image.getPath())); ret.addAll(labelMultiGenerator.getLabels(image.getPath()));
@ -286,7 +298,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
@Override @Override
public List<List<Writable>> next(int num) { public List<List<Writable>> next(int num) {
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got " + num); Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num);
if (imageLoader == null) { if (imageLoader == null) {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform); imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
@ -337,6 +349,9 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
if(!nchw_channels_first){
features = features.permute(0,2,3,1); //NCHW to NHWC
}
Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE); Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);
@ -483,8 +498,10 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
if (imageLoader == null) { if (imageLoader == null) {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform); imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
} }
INDArray row = imageLoader.asMatrix(dataInputStream); INDArray array = imageLoader.asMatrix(dataInputStream);
List<Writable> ret = RecordConverter.toRecord(row); if(!nchw_channels_first)
array = array.permute(0,2,3,1);
List<Writable> ret = RecordConverter.toRecord(array);
if (appendLabel) if (appendLabel)
ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath())))); ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath()))));
return ret; return ret;

View File

@ -34,47 +34,70 @@ import org.datavec.image.transform.ImageTransform;
public class ImageRecordReader extends BaseImageRecordReader { public class ImageRecordReader extends BaseImageRecordReader {
/** Loads images with height = 28, width = 28, and channels = 1, appending no labels. */ /** Loads images with height = 28, width = 28, and channels = 1, appending no labels.
* Output format is NCHW (channels first) - [numExamples, 1, 28, 28]*/
public ImageRecordReader() { public ImageRecordReader() {
super(); super();
} }
/** Loads images with given height, width, and channels, appending labels returned by the generator. */ /** Loads images with given height, width, and channels, appending labels returned by the generator.
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
*/
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) { public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
super(height, width, channels, labelGenerator); super(height, width, channels, labelGenerator);
} }
/** Loads images with given height, width, and channels, appending labels returned by the generator. */ /** Loads images with given height, width, and channels, appending labels returned by the generator.
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
*/
public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) { public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) {
super(height, width, channels, labelGenerator); super(height, width, channels, labelGenerator);
} }
/** Loads images with given height, width, and channels, appending no labels. */ /** Loads images with given height, width, and channels, appending no labels - in NCHW (channels first) format */
public ImageRecordReader(long height, long width, long channels) { public ImageRecordReader(long height, long width, long channels) {
super(height, width, channels, (PathLabelGenerator) null); super(height, width, channels, (PathLabelGenerator) null);
} }
/** Loads images with given height, width, and channels, appending labels returned by the generator. */ /** Loads images with given height, width, and channels, appending no labels - in specified format<br>
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
*/
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first) {
super(height, width, channels, nchw_channels_first, null, null, null);
}
/** Loads images with given height, width, and channels, appending labels returned by the generator.
* Output format is NCHW (channels first) - [numExamples, channels, height, width] */
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator, public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
ImageTransform imageTransform) { ImageTransform imageTransform) {
super(height, width, channels, labelGenerator, imageTransform); super(height, width, channels, labelGenerator, imageTransform);
} }
/** Loads images with given height, width, and channels, appending no labels. */ /** Loads images with given height, width, and channels, appending labels returned by the generator.<br>
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
*/
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
ImageTransform imageTransform) {
super(height, width, channels, nchw_channels_first, labelGenerator, null, imageTransform);
}
/** Loads images with given height, width, and channels, appending no labels.
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) { public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) {
super(height, width, channels, null, imageTransform); super(height, width, channels, null, imageTransform);
} }
/** Loads images with given height, width, and channels, appending labels returned by the generator. */ /** Loads images with given height, width, and channels, appending labels returned by the generator
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) { public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) {
super(height, width, 1, labelGenerator); super(height, width, 1, labelGenerator);
} }
/** Loads images with given height, width, and channels = 1, appending no labels. */ /** Loads images with given height, width, and channels = 1, appending no labels.
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
public ImageRecordReader(long height, long width) { public ImageRecordReader(long height, long width) {
super(height, width, 1, null, null); super(height, width, 1, null, null);
} }
} }

View File

@ -35,13 +35,13 @@ import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.*;
import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -467,5 +467,85 @@ public class TestImageRecordReader {
return count; return count;
} }
} }
@Test
public void testNCHW_NCHW() throws Exception {
//Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder();
Resources.copyDirectory("datavec-data-image/testimages/", f0);
FileSplit fs = new FileSplit(f0, new Random(12345));
assertEquals(6, fs.locations().length);
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
nchw.initialize(fs);
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
nhwc.initialize(fs);
while(nchw.hasNext()){
assertTrue(nhwc.hasNext());
List<Writable> l_nchw = nchw.next();
List<Writable> l_nhwc = nhwc.next();
INDArray a_nchw = ((NDArrayWritable)l_nchw.get(0)).get();
INDArray a_nhwc = ((NDArrayWritable)l_nhwc.get(0)).get();
assertArrayEquals(new long[]{1, 3, 32, 32}, a_nchw.shape());
assertArrayEquals(new long[]{1, 32, 32, 3}, a_nhwc.shape());
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
assertEquals(a_nchw, permuted);
}
//Test batch:
nchw.reset();
nhwc.reset();
int batchCount = 0;
while(nchw.hasNext()){
assertTrue(nhwc.hasNext());
batchCount++;
List<List<Writable>> l_nchw = nchw.next(3);
List<List<Writable>> l_nhwc = nhwc.next(3);
assertEquals(3, l_nchw.size());
assertEquals(3, l_nhwc.size());
NDArrayRecordBatch b_nchw = (NDArrayRecordBatch)l_nchw;
NDArrayRecordBatch b_nhwc = (NDArrayRecordBatch)l_nhwc;
INDArray a_nchw = b_nchw.getArrays().get(0);
INDArray a_nhwc = b_nhwc.getArrays().get(0);
assertArrayEquals(new long[]{3, 3, 32, 32}, a_nchw.shape());
assertArrayEquals(new long[]{3, 32, 32, 3}, a_nhwc.shape());
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
assertEquals(a_nchw, permuted);
}
assertEquals(2, batchCount);
//Test record(URI, DataInputStream)
URI u = fs.locations()[0];
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
List<Writable> l = nchw.record(u, dis);
INDArray arr = ((NDArrayWritable)l.get(0)).get();
assertArrayEquals(new long[]{1, 3, 32, 32}, arr.shape());
}
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
List<Writable> l = nhwc.record(u, dis);
INDArray arr = ((NDArrayWritable)l.get(0)).get();
assertArrayEquals(new long[]{1, 32, 32, 3}, arr.shape());
}
}
} }