ImageRecordReader - NHWC support (#431)
* #8201 ImageRecordReader - NHWC support
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Check values - NCHW vs. NHWC
  Signed-off-by: Alex Black <blacka101@gmail.com>
master
parent
4dbdaca967
commit
092e6b9891
|
@ -52,6 +52,7 @@ public class NDArrayRecordBatch extends AbstractWritableRecordBatch {
|
|||
public NDArrayRecordBatch(@NonNull List<INDArray> arrays){
|
||||
Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty");
|
||||
this.arrays = arrays;
|
||||
this.size = arrays.get(0).size(0);
|
||||
|
||||
//Check that dimension 0 matches:
|
||||
if(arrays.size() > 1){
|
||||
|
|
|
@ -77,6 +77,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
protected int patternPosition = 0;
|
||||
@Getter @Setter
|
||||
protected boolean logLabelCountOnInit = true;
|
||||
@Getter @Setter
|
||||
protected boolean nchw_channels_first = true;
|
||||
|
||||
public final static String HEIGHT = NAME_SPACE + ".height";
|
||||
public final static String WIDTH = NAME_SPACE + ".width";
|
||||
|
@ -101,6 +103,11 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
|
||||
protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
|
||||
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
|
||||
this(height, width, channels, true, labelGenerator, labelMultiGenerator, imageTransform);
|
||||
}
|
||||
|
||||
protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
|
||||
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
|
||||
this.height = height;
|
||||
this.width = width;
|
||||
this.channels = channels;
|
||||
|
@ -108,6 +115,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
this.labelMultiGenerator = labelMultiGenerator;
|
||||
this.imageTransform = imageTransform;
|
||||
this.appendLabel = (labelGenerator != null || labelMultiGenerator != null);
|
||||
this.nchw_channels_first = nchw_channels_first;
|
||||
}
|
||||
|
||||
protected boolean containsFormat(String format) {
|
||||
|
@ -237,9 +245,13 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
return next();
|
||||
try {
|
||||
invokeListeners(image);
|
||||
INDArray row = imageLoader.asMatrix(image);
|
||||
Nd4j.getAffinityManager().ensureLocation(row, AffinityManager.Location.DEVICE);
|
||||
ret = RecordConverter.toRecord(row);
|
||||
INDArray array = imageLoader.asMatrix(image);
|
||||
if(!nchw_channels_first){
|
||||
array = array.permute(0,2,3,1); //NCHW to NHWC
|
||||
}
|
||||
|
||||
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
|
||||
ret = RecordConverter.toRecord(array);
|
||||
if (appendLabel || writeLabel){
|
||||
if(labelMultiGenerator != null){
|
||||
ret.addAll(labelMultiGenerator.getLabels(image.getPath()));
|
||||
|
@ -286,7 +298,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
|
||||
@Override
|
||||
public List<List<Writable>> next(int num) {
|
||||
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got " + num);
|
||||
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num);
|
||||
|
||||
if (imageLoader == null) {
|
||||
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
||||
|
@ -337,6 +349,9 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
if(!nchw_channels_first){
|
||||
features = features.permute(0,2,3,1); //NCHW to NHWC
|
||||
}
|
||||
Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);
|
||||
|
||||
|
||||
|
@ -483,8 +498,10 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
if (imageLoader == null) {
|
||||
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
||||
}
|
||||
INDArray row = imageLoader.asMatrix(dataInputStream);
|
||||
List<Writable> ret = RecordConverter.toRecord(row);
|
||||
INDArray array = imageLoader.asMatrix(dataInputStream);
|
||||
if(!nchw_channels_first)
|
||||
array = array.permute(0,2,3,1);
|
||||
List<Writable> ret = RecordConverter.toRecord(array);
|
||||
if (appendLabel)
|
||||
ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath()))));
|
||||
return ret;
|
||||
|
|
|
@ -34,47 +34,70 @@ import org.datavec.image.transform.ImageTransform;
|
|||
public class ImageRecordReader extends BaseImageRecordReader {
|
||||
|
||||
|
||||
/** Loads images with height = 28, width = 28, and channels = 1, appending no labels. */
|
||||
/** Loads images with height = 28, width = 28, and channels = 1, appending no labels.
|
||||
* Output format is NCHW (channels first) - [numExamples, 1, 28, 28]*/
|
||||
public ImageRecordReader() {
|
||||
super();
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
|
||||
*/
|
||||
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
|
||||
super(height, width, channels, labelGenerator);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
|
||||
*/
|
||||
public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) {
|
||||
super(height, width, channels, labelGenerator);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending no labels. */
|
||||
/** Loads images with given height, width, and channels, appending no labels - in NCHW (channels first) format */
|
||||
public ImageRecordReader(long height, long width, long channels) {
|
||||
super(height, width, channels, (PathLabelGenerator) null);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
||||
/** Loads images with given height, width, and channels, appending no labels - in specified format<br>
|
||||
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
|
||||
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
|
||||
*/
|
||||
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first) {
|
||||
super(height, width, channels, nchw_channels_first, null, null, null);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width] */
|
||||
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
|
||||
ImageTransform imageTransform) {
|
||||
super(height, width, channels, labelGenerator, imageTransform);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending no labels. */
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator.<br>
|
||||
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
|
||||
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
|
||||
*/
|
||||
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
|
||||
ImageTransform imageTransform) {
|
||||
super(height, width, channels, nchw_channels_first, labelGenerator, null, imageTransform);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending no labels.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
|
||||
public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) {
|
||||
super(height, width, channels, null, imageTransform);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
||||
/** Loads images with given height and width, and channels = 1, appending labels returned by the generator.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
|
||||
public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) {
|
||||
super(height, width, 1, labelGenerator);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels = 1, appending no labels. */
|
||||
/** Loads images with given height, width, and channels = 1, appending no labels.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
|
||||
public ImageRecordReader(long height, long width) {
|
||||
super(height, width, 1, null, null);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -35,13 +35,13 @@ import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.*;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -467,5 +467,85 @@ public class TestImageRecordReader {
|
|||
return count;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testNCHW_NCHW() throws Exception {
|
||||
//Idea: NCHW and NHWC readers should return the same values, differing only in dimension order
|
||||
File f0 = testDir.newFolder();
|
||||
Resources.copyDirectory("datavec-data-image/testimages/", f0);
|
||||
|
||||
FileSplit fs = new FileSplit(f0, new Random(12345));
|
||||
assertEquals(6, fs.locations().length);
|
||||
|
||||
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
|
||||
nchw.initialize(fs);
|
||||
|
||||
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
|
||||
nhwc.initialize(fs);
|
||||
|
||||
while(nchw.hasNext()){
|
||||
assertTrue(nhwc.hasNext());
|
||||
|
||||
List<Writable> l_nchw = nchw.next();
|
||||
List<Writable> l_nhwc = nhwc.next();
|
||||
|
||||
INDArray a_nchw = ((NDArrayWritable)l_nchw.get(0)).get();
|
||||
INDArray a_nhwc = ((NDArrayWritable)l_nhwc.get(0)).get();
|
||||
|
||||
assertArrayEquals(new long[]{1, 3, 32, 32}, a_nchw.shape());
|
||||
assertArrayEquals(new long[]{1, 32, 32, 3}, a_nhwc.shape());
|
||||
|
||||
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
|
||||
assertEquals(a_nchw, permuted);
|
||||
}
|
||||
|
||||
|
||||
//Test batch:
|
||||
nchw.reset();
|
||||
nhwc.reset();
|
||||
|
||||
int batchCount = 0;
|
||||
while(nchw.hasNext()){
|
||||
assertTrue(nhwc.hasNext());
|
||||
batchCount++;
|
||||
|
||||
List<List<Writable>> l_nchw = nchw.next(3);
|
||||
List<List<Writable>> l_nhwc = nhwc.next(3);
|
||||
assertEquals(3, l_nchw.size());
|
||||
assertEquals(3, l_nhwc.size());
|
||||
|
||||
NDArrayRecordBatch b_nchw = (NDArrayRecordBatch)l_nchw;
|
||||
NDArrayRecordBatch b_nhwc = (NDArrayRecordBatch)l_nhwc;
|
||||
|
||||
INDArray a_nchw = b_nchw.getArrays().get(0);
|
||||
INDArray a_nhwc = b_nhwc.getArrays().get(0);
|
||||
|
||||
assertArrayEquals(new long[]{3, 3, 32, 32}, a_nchw.shape());
|
||||
assertArrayEquals(new long[]{3, 32, 32, 3}, a_nhwc.shape());
|
||||
|
||||
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
|
||||
assertEquals(a_nchw, permuted);
|
||||
}
|
||||
assertEquals(2, batchCount);
|
||||
|
||||
|
||||
//Test record(URI, DataInputStream)
|
||||
|
||||
URI u = fs.locations()[0];
|
||||
|
||||
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
|
||||
List<Writable> l = nchw.record(u, dis);
|
||||
INDArray arr = ((NDArrayWritable)l.get(0)).get();
|
||||
assertArrayEquals(new long[]{1, 3, 32, 32}, arr.shape());
|
||||
}
|
||||
|
||||
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
|
||||
List<Writable> l = nhwc.record(u, dis);
|
||||
INDArray arr = ((NDArrayWritable)l.get(0)).get();
|
||||
assertArrayEquals(new long[]{1, 32, 32, 3}, arr.shape());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue