ImageRecordReader - NHWC support (#431)

* #8201 ImageRecordReader - NHWC support

Signed-off-by: Alex Black <blacka101@gmail.com>

* Check values - NCHW vs. NHWC

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-04 14:25:19 +10:00 committed by GitHub
parent 4dbdaca967
commit 092e6b9891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 140 additions and 19 deletions

View File

@ -52,6 +52,7 @@ public class NDArrayRecordBatch extends AbstractWritableRecordBatch {
public NDArrayRecordBatch(@NonNull List<INDArray> arrays){
Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty");
this.arrays = arrays;
this.size = arrays.get(0).size(0);
//Check that dimension 0 matches:
if(arrays.size() > 1){

View File

@ -77,6 +77,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
protected int patternPosition = 0;
@Getter @Setter
protected boolean logLabelCountOnInit = true;
/**
 * Selects the dimension order of the image arrays produced by this reader.
 * When {@code true} (the default) arrays are returned exactly as the image loader produced them;
 * when {@code false} the read paths in this class apply {@code permute(0,2,3,1)} to the loaded
 * array (NCHW to NHWC, i.e. channels moved to the last dimension) before converting it to a record.
 */
@Getter @Setter
protected boolean nchw_channels_first = true;
public final static String HEIGHT = NAME_SPACE + ".height";
public final static String WIDTH = NAME_SPACE + ".width";
@ -101,6 +103,11 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
/**
 * Creates a reader without an explicit array-format argument.
 * Delegates to the seven-argument constructor, passing {@code true} as the fourth argument
 * (the {@code nchw_channels_first} flag of that overload), so the channels-first format is used.
 *
 * @param height              image height, forwarded unchanged
 * @param width               image width, forwarded unchanged
 * @param channels            number of image channels, forwarded unchanged
 * @param labelGenerator      single-label generator, forwarded unchanged
 * @param labelMultiGenerator multi-label generator, forwarded unchanged
 * @param imageTransform      image transform, forwarded unchanged
 */
protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
this(height, width, channels, true, labelGenerator, labelMultiGenerator, imageTransform);
}
protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
this.height = height;
this.width = width;
this.channels = channels;
@ -108,6 +115,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
this.labelMultiGenerator = labelMultiGenerator;
this.imageTransform = imageTransform;
this.appendLabel = (labelGenerator != null || labelMultiGenerator != null);
this.nchw_channels_first = nchw_channels_first;
}
protected boolean containsFormat(String format) {
@ -237,9 +245,13 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
return next();
try {
invokeListeners(image);
INDArray row = imageLoader.asMatrix(image);
Nd4j.getAffinityManager().ensureLocation(row, AffinityManager.Location.DEVICE);
ret = RecordConverter.toRecord(row);
INDArray array = imageLoader.asMatrix(image);
if(!nchw_channels_first){
array = array.permute(0,2,3,1); //NCHW to NHWC
}
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
ret = RecordConverter.toRecord(array);
if (appendLabel || writeLabel){
if(labelMultiGenerator != null){
ret.addAll(labelMultiGenerator.getLabels(image.getPath()));
@ -286,7 +298,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
@Override
public List<List<Writable>> next(int num) {
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got " + num);
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num);
if (imageLoader == null) {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
@ -337,6 +349,9 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
throw new RuntimeException(e);
}
}
if(!nchw_channels_first){
features = features.permute(0,2,3,1); //NCHW to NHWC
}
Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);
@ -483,8 +498,10 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
if (imageLoader == null) {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
}
INDArray row = imageLoader.asMatrix(dataInputStream);
List<Writable> ret = RecordConverter.toRecord(row);
INDArray array = imageLoader.asMatrix(dataInputStream);
if(!nchw_channels_first)
array = array.permute(0,2,3,1);
List<Writable> ret = RecordConverter.toRecord(array);
if (appendLabel)
ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath()))));
return ret;

View File

@ -34,47 +34,70 @@ import org.datavec.image.transform.ImageTransform;
/**
 * Concrete image record reader. Every constructor simply forwards its arguments to a
 * {@link BaseImageRecordReader} constructor; the overloads taking {@code nchw_channels_first}
 * allow the output array format (NCHW or NHWC) to be chosen, all others use NCHW.
 */
public class ImageRecordReader extends BaseImageRecordReader {
/**
 * Loads images with height = 28, width = 28, and channels = 1, appending no labels.
 * Output format is NCHW (channels first) - [numExamples, 1, 28, 28]
 */
public ImageRecordReader() {
super();
}
/**
 * Loads images with given height, width, and channels, appending labels returned by the generator.
 * Output format is NCHW (channels first) - [numExamples, channels, height, width]
 */
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
super(height, width, channels, labelGenerator);
}
/**
 * Loads images with given height, width, and channels, appending the labels returned by the
 * multi-label generator.
 * Output format is NCHW (channels first) - [numExamples, channels, height, width]
 */
public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) {
super(height, width, channels, labelGenerator);
}
/**
 * Loads images with given height, width, and channels, appending no labels.
 * Output format is NCHW (channels first) - [numExamples, channels, height, width]
 */
public ImageRecordReader(long height, long width, long channels) {
// The cast picks the PathLabelGenerator overload; a bare null would not identify one
super(height, width, channels, (PathLabelGenerator) null);
}
/**
 * Loads images with given height, width, and channels, appending no labels - in specified format<br>
 * If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
 * If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
 */
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first) {
// No label generator, no multi-label generator, no image transform
super(height, width, channels, nchw_channels_first, null, null, null);
}
/**
 * Loads images with given height, width, and channels, appending labels returned by the generator.
 * The given image transform is passed on to the base reader.
 * Output format is NCHW (channels first) - [numExamples, channels, height, width]
 */
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
ImageTransform imageTransform) {
super(height, width, channels, labelGenerator, imageTransform);
}
/**
 * Loads images with given height, width, and channels, appending labels returned by the generator.<br>
 * The given image transform is passed on to the base reader.<br>
 * If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
 * If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
 */
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
ImageTransform imageTransform) {
// The multi-label generator slot is left null; only the single-label generator is supplied
super(height, width, channels, nchw_channels_first, labelGenerator, null, imageTransform);
}
/**
 * Loads images with given height, width, and channels, appending no labels.
 * The given image transform is passed on to the base reader.
 * Output format is NCHW (channels first) - [numExamples, channels, height, width]
 */
public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) {
super(height, width, channels, null, imageTransform);
}
/**
 * Loads images with given height and width, and channels = 1, appending labels returned by the generator.
 * Output format is NCHW (channels first) - [numExamples, 1, height, width]
 */
public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) {
super(height, width, 1, labelGenerator);
}
/**
 * Loads images with given height and width, and channels = 1, appending no labels.
 * Output format is NCHW (channels first) - [numExamples, 1, height, width]
 */
public ImageRecordReader(long height, long width) {
super(height, width, 1, null, null);
}
}

View File

@ -35,13 +35,13 @@ import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.io.IOException;
import java.io.*;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
@ -467,5 +467,85 @@ public class TestImageRecordReader {
return count;
}
}
/**
 * Verifies that a reader created with {@code nchw_channels_first == false} (NHWC, channels last)
 * yields the same pixel values as one created with {@code nchw_channels_first == true}
 * (NCHW, channels first), differing only in dimension order.
 * <p>
 * All three read paths are covered: {@code next()}, {@code next(int)} and
 * {@code record(URI, DataInputStream)}. Both readers are initialized from the same
 * {@link FileSplit}, so they are expected to visit the images in the same order.
 * <p>
 * Note: the method name reads "NCHW_NCHW" although the comparison is NCHW vs. NHWC; the name is
 * left unchanged.
 *
 * @throws Exception if the test images cannot be copied or read
 */
@Test
public void testNCHW_NCHW() throws Exception {
    File f0 = testDir.newFolder();
    Resources.copyDirectory("datavec-data-image/testimages/", f0);

    FileSplit fs = new FileSplit(f0, new Random(12345));
    assertEquals(6, fs.locations().length);

    ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
    nchw.initialize(fs);

    ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
    nhwc.initialize(fs);

    //Test single examples: next()
    int exampleCount = 0;
    while (nchw.hasNext()) {
        assertTrue(nhwc.hasNext());
        exampleCount++;

        List<Writable> l_nchw = nchw.next();
        List<Writable> l_nhwc = nhwc.next();

        INDArray a_nchw = ((NDArrayWritable) l_nchw.get(0)).get();
        INDArray a_nhwc = ((NDArrayWritable) l_nhwc.get(0)).get();

        assertArrayEquals(new long[]{1, 3, 32, 32}, a_nchw.shape());
        assertArrayEquals(new long[]{1, 32, 32, 3}, a_nhwc.shape());

        INDArray permuted = a_nhwc.permute(0, 3, 1, 2);    //NHWC to NCHW
        assertEquals(a_nchw, permuted);
    }
    //Guard against a vacuous pass (loop body never executed) and against the readers getting out of step
    assertEquals(6, exampleCount);
    assertEquals(nchw.hasNext(), nhwc.hasNext());

    //Test batch: next(int)
    nchw.reset();
    nhwc.reset();

    int batchCount = 0;
    while (nchw.hasNext()) {
        assertTrue(nhwc.hasNext());
        batchCount++;

        List<List<Writable>> l_nchw = nchw.next(3);
        List<List<Writable>> l_nhwc = nhwc.next(3);
        assertEquals(3, l_nchw.size());
        assertEquals(3, l_nhwc.size());

        NDArrayRecordBatch b_nchw = (NDArrayRecordBatch) l_nchw;
        NDArrayRecordBatch b_nhwc = (NDArrayRecordBatch) l_nhwc;

        INDArray a_nchw = b_nchw.getArrays().get(0);
        INDArray a_nhwc = b_nhwc.getArrays().get(0);

        assertArrayEquals(new long[]{3, 3, 32, 32}, a_nchw.shape());
        assertArrayEquals(new long[]{3, 32, 32, 3}, a_nhwc.shape());

        INDArray permuted = a_nhwc.permute(0, 3, 1, 2);    //NHWC to NCHW
        assertEquals(a_nchw, permuted);
    }
    assertEquals(2, batchCount);
    assertEquals(nchw.hasNext(), nhwc.hasNext());

    //Test record(URI, DataInputStream): check shapes AND values, as for the other two read paths
    URI u = fs.locations()[0];
    INDArray recordNchw;
    try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
        List<Writable> l = nchw.record(u, dis);
        recordNchw = ((NDArrayWritable) l.get(0)).get();
        assertArrayEquals(new long[]{1, 3, 32, 32}, recordNchw.shape());
    }

    try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
        List<Writable> l = nhwc.record(u, dis);
        INDArray recordNhwc = ((NDArrayWritable) l.get(0)).get();
        assertArrayEquals(new long[]{1, 32, 32, 3}, recordNhwc.shape());

        INDArray permuted = recordNhwc.permute(0, 3, 1, 2);    //NHWC to NCHW
        assertEquals(recordNchw, permuted);
    }
}
}