From 7651a486e10ac6a4d01f7193840006500ff8da30 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 4 May 2020 17:17:57 +1000 Subject: [PATCH] #8201 ImageLoader/NativeImageLoader NHWC support (#432) Signed-off-by: Alex Black --- .../datavec/image/loader/BaseImageLoader.java | 28 ++++++- .../org/datavec/image/loader/CifarLoader.java | 2 + .../org/datavec/image/loader/ImageLoader.java | 75 ++++++++++++++----- .../org/datavec/image/loader/LFWLoader.java | 22 ++++++ .../image/loader/NativeImageLoader.java | 40 ++++++++-- .../datavec/image/loader/TestImageLoader.java | 59 +++++++++++++++ .../image/loader/TestNativeImageLoader.java | 59 ++++++++++++++- .../recordreader/TestImageRecordReader.java | 14 ++-- 8 files changed, 263 insertions(+), 36 deletions(-) diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java index d2518eeb6..0bfddeb32 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java @@ -16,6 +16,7 @@ package org.datavec.image.loader; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.datavec.image.data.Image; import org.datavec.image.transform.ImageTransform; @@ -35,10 +36,9 @@ import java.util.Random; /** * Created by nyghtowl on 12/17/15. */ +@Slf4j public abstract class BaseImageLoader implements Serializable { - protected static final Logger log = LoggerFactory.getLogger(BaseImageLoader.class); - public enum MultiPageMode { MINIBATCH, FIRST //, CHANNELS, } @@ -62,13 +62,37 @@ public abstract class BaseImageLoader implements Serializable { public abstract INDArray asRowVector(InputStream inputStream) throws IOException; + /** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */ public abstract INDArray asMatrix(File f) throws IOException; + /** + * Load an image from a file to an INDArray + * @param f File to load the image from + * @param nchw If true: return image in NCHW/channels_first [1, channels, height width] format; if false, return + * in NHWC/channels_last [1, height, width, channels] format + * @return Image file as as INDArray + */ + public abstract INDArray asMatrix(File f, boolean nchw) throws IOException; + public abstract INDArray asMatrix(InputStream inputStream) throws IOException; + /** + * Load an image file from an input stream to an INDArray + * @param inputStream Input stream to load the image from + * @param nchw If true: return image in NCHW/channels_first [1, channels, height width] format; if false, return + * in NHWC/channels_last [1, height, width, channels] format + * @return Image file stream as as INDArray + */ + public abstract INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException; + /** As per {@link #asMatrix(File)} but as an {@link Image}*/ public abstract Image asImageMatrix(File f) throws IOException; + /** As per {@link #asMatrix(File, boolean)} but as an {@link Image}*/ + public abstract Image asImageMatrix(File f, boolean nchw) throws IOException; + /** As per {@link #asMatrix(InputStream)} but as an {@link Image}*/ public abstract Image asImageMatrix(InputStream inputStream) throws IOException; + /** As per {@link #asMatrix(InputStream, boolean)} but as an {@link Image}*/ + public abstract Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException; public static void downloadAndUntar(Map urlMap, File fullDir) { diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java index 3d390c698..e513ebed3 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java @@ -16,6 +16,7 @@ package org.datavec.image.loader; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.bytedeco.javacv.OpenCVFrameConverter; @@ -47,6 +48,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*; * There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example * https://github.com/szagoruyko/cifar.torch */ +@Slf4j public class CifarLoader extends NativeImageLoader implements Serializable { public static final int NUM_TRAIN_IMAGES = 50000; public static final int NUM_TEST_IMAGES = 10000; diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java index d246c65ad..9c2c61d57 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java @@ -249,7 +249,14 @@ public class ImageLoader extends BaseImageLoader { * @throws IOException */ public INDArray asMatrix(File f) throws IOException { - return NDArrayUtil.toNDArray(fromFile(f)); + return asMatrix(f, true); + } + + @Override + public INDArray asMatrix(File f, boolean nchw) throws IOException { + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + return asMatrix(is, nchw); + } } /** @@ -259,34 +266,68 @@ public class ImageLoader extends BaseImageLoader { * @return the input stream to convert */ public INDArray asMatrix(InputStream inputStream) throws IOException { - if (channels == 3) - return toBgr(inputStream); - try { - BufferedImage image = ImageIO.read(inputStream); - return asMatrix(image); - } catch (IOException e) { - throw new IOException("Unable to load image", e); + return asMatrix(inputStream, true); + } + + @Override + public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { + INDArray ret; + if (channels == 3) { + ret = toBgr(inputStream); + } else { + try { + BufferedImage image = ImageIO.read(inputStream); + ret = asMatrix(image); + } catch (IOException e) { + throw new IOException("Unable to load image", e); + } } + if(ret.rank() == 3){ + ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2)); + } + if(!nchw) + ret = ret.permute(0,2,3,1); //NCHW to NHWC + return ret; } @Override public org.datavec.image.data.Image asImageMatrix(File f) throws IOException { + return asImageMatrix(f, true); + } + + @Override + public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { - return asImageMatrix(bis); + return asImageMatrix(bis, nchw); } } @Override public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException { - if (channels == 3) - return toBgrImage(inputStream); - try { - BufferedImage image = ImageIO.read(inputStream); - INDArray asMatrix = asMatrix(image); - return new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth()); - } catch (IOException e) { - throw new IOException("Unable to load image", e); + return asImageMatrix(inputStream, true); + } + + @Override + public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException { + org.datavec.image.data.Image ret; + if (channels == 3) { + ret = toBgrImage(inputStream); + } else { + try { + BufferedImage image = ImageIO.read(inputStream); + INDArray asMatrix = asMatrix(image); + ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth()); + } catch (IOException e) { + throw new IOException("Unable to load image", e); + } } + if(ret.getImage().rank() == 3){ + INDArray a = ret.getImage(); + ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2))); + } + if(!nchw) + ret.setImage(ret.getImage().permute(0,2,3,1)); //NCHW to NHWC + return ret; } /** diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java index d28c73318..b71c53e42 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java @@ -17,6 +17,7 @@ package org.datavec.image.loader; +import lombok.extern.slf4j.Slf4j; import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.labels.PathLabelGenerator; import org.datavec.api.io.labels.PatternPathLabelGenerator; @@ -48,6 +49,7 @@ import java.util.Random; * most images are in color, although a few are grayscale * */ +@Slf4j public class LFWLoader extends BaseImageLoader implements Serializable { public final static int NUM_IMAGES = 13233; @@ -270,19 +272,39 @@ public class LFWLoader extends BaseImageLoader implements Serializable { throw new UnsupportedOperationException(); } + @Override + public INDArray asMatrix(File f, boolean nchw) throws IOException { + throw new UnsupportedOperationException(); + } + @Override public INDArray asMatrix(InputStream inputStream) throws IOException { throw new UnsupportedOperationException(); } + @Override + public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { + throw new UnsupportedOperationException(); + } + @Override public Image asImageMatrix(File f) throws IOException { throw new UnsupportedOperationException(); } + @Override + public Image asImageMatrix(File f, boolean nchw) throws IOException { + throw new UnsupportedOperationException(); + } + @Override public Image asImageMatrix(InputStream inputStream) throws IOException { throw new UnsupportedOperationException(); } + @Override + public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException { + throw new UnsupportedOperationException(); + } + } diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java index 88bc161f2..ae9e2a322 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java @@ -248,17 +248,27 @@ public class NativeImageLoader extends BaseImageLoader { @Override public INDArray asMatrix(File f) throws IOException { + return asMatrix(f, true); + } + + @Override + public INDArray asMatrix(File f, boolean nchw) throws IOException { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { - return asMatrix(bis); + return asMatrix(bis, nchw); } } @Override public INDArray asMatrix(InputStream is) throws IOException { - Mat mat = streamToMat(is); + return asMatrix(is, true); + } + + @Override + public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { + Mat mat = streamToMat(inputStream); INDArray a; if (this.multiPageMode != null) { - a = asMatrix(mat.data(), mat.cols()); + a = asMatrix(mat.data(), mat.cols()); }else{ Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); if (image == null || image.empty()) { @@ -272,7 +282,11 @@ public class NativeImageLoader extends BaseImageLoader { a = asMatrix(image); image.deallocate(); } - return a; + if(nchw) { + return a; + } else { + return a.permute(0, 2, 3, 1); //NCHW to NHWC + } } /** @@ -331,19 +345,29 @@ public class NativeImageLoader extends BaseImageLoader { } public Image asImageMatrix(String filename) throws IOException { - return asImageMatrix(filename); + return asImageMatrix(new File(filename)); } @Override public Image asImageMatrix(File f) throws IOException { + return asImageMatrix(f, true); + } + + @Override + public Image asImageMatrix(File f, boolean nchw) throws IOException { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { - return asImageMatrix(bis); + return asImageMatrix(bis, nchw); } } @Override public Image asImageMatrix(InputStream is) throws IOException { - Mat mat = streamToMat(is); + return asImageMatrix(is, true); + } + + @Override + public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException { + Mat mat = streamToMat(inputStream); Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); if (image == null || image.empty()) { PIX pix = pixReadMem(mat.data(), mat.cols()); @@ -354,6 +378,8 @@ public class NativeImageLoader extends BaseImageLoader { pixDestroy(pix); } INDArray a = asMatrix(image); + if(!nchw) + a = a.permute(0,2,3,1); //NCHW to NHWC Image i = new Image(a, image.channels(), image.rows(), image.cols()); image.deallocate(); diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java index 1683980f0..a82f12409 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java @@ -16,10 +16,16 @@ package org.datavec.image.loader; +import org.datavec.image.data.Image; import org.junit.Test; +import org.nd4j.common.resources.Resources; import org.nd4j.linalg.api.ndarray.INDArray; import java.awt.image.BufferedImage; +import java.io.BufferedInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; import java.util.Random; import static org.junit.Assert.assertEquals; @@ -208,4 +214,57 @@ public class TestImageLoader { private BufferedImage makeRandomBufferedImage(boolean alpha) { return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100); } + + + @Test + public void testNCHW_NHWC() throws Exception { + File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg"); + + ImageLoader il = new ImageLoader(32, 32, 3); + + //asMatrix(File, boolean) + INDArray a_nchw = il.asMatrix(f); + INDArray a_nchw2 = il.asMatrix(f, true); + INDArray a_nhwc = il.asMatrix(f, false); + + assertEquals(a_nchw, a_nchw2); + assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); + + + //asMatrix(InputStream, boolean) + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nchw = il.asMatrix(is); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nchw2 = il.asMatrix(is, true); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nhwc = il.asMatrix(is, false); + } + assertEquals(a_nchw, a_nchw2); + assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); + + + //asImageMatrix(File, boolean) + Image i_nchw = il.asImageMatrix(f); + Image i_nchw2 = il.asImageMatrix(f, true); + Image i_nhwc = il.asImageMatrix(f, false); + + assertEquals(i_nchw.getImage(), i_nchw2.getImage()); + assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW + + + //asImageMatrix(InputStream, boolean) + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nchw = il.asImageMatrix(is); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nchw2 = il.asImageMatrix(is, true); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nhwc = il.asImageMatrix(is, false); + } + assertEquals(i_nchw.getImage(), i_nchw2.getImage()); + assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW + } } diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java index 6e7705569..68e93107c 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java @@ -24,20 +24,19 @@ import org.bytedeco.javacpp.indexer.UByteIndexer; import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter; +import org.datavec.image.data.Image; import org.datavec.image.data.ImageWritable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.common.resources.Resources; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; import java.awt.image.BufferedImage; -import java.io.File; -import java.io.FileInputStream; -import java.io.InputStream; -import java.io.IOException; +import java.io.*; import java.lang.reflect.Field; import java.util.Random; @@ -604,4 +603,56 @@ public class TestNativeImageLoader { } } + @Test + public void testNCHW_NHWC() throws Exception { + File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg"); + + NativeImageLoader il = new NativeImageLoader(32, 32, 3); + + //asMatrix(File, boolean) + INDArray a_nchw = il.asMatrix(f); + INDArray a_nchw2 = il.asMatrix(f, true); + INDArray a_nhwc = il.asMatrix(f, false); + + assertEquals(a_nchw, a_nchw2); + assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); + + + //asMatrix(InputStream, boolean) + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nchw = il.asMatrix(is); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nchw2 = il.asMatrix(is, true); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nhwc = il.asMatrix(is, false); + } + assertEquals(a_nchw, a_nchw2); + assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); + + + //asImageMatrix(File, boolean) + Image i_nchw = il.asImageMatrix(f); + Image i_nchw2 = il.asImageMatrix(f, true); + Image i_nhwc = il.asImageMatrix(f, false); + + assertEquals(i_nchw.getImage(), i_nchw2.getImage()); + assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW + + + //asImageMatrix(InputStream, boolean) + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nchw = il.asImageMatrix(is); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nchw2 = il.asImageMatrix(is, true); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nhwc = il.asImageMatrix(is, false); + } + assertEquals(i_nchw.getImage(), i_nchw2.getImage()); + assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW + } + } diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java index fdcbe959a..26cd83f06 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java @@ -474,16 +474,18 @@ public class TestImageRecordReader { public void testNCHW_NCHW() throws Exception { //Idea: labels order should be consistent regardless of input file order File f0 = testDir.newFolder(); - Resources.copyDirectory("datavec-data-image/testimages/", f0); + new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0); - FileSplit fs = new FileSplit(f0, new Random(12345)); - assertEquals(6, fs.locations().length); + FileSplit fs0 = new FileSplit(f0, new Random(12345)); + FileSplit fs1 = new FileSplit(f0, new Random(12345)); + assertEquals(6, fs0.locations().length); + assertEquals(6, fs1.locations().length); ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true); - nchw.initialize(fs); + nchw.initialize(fs0); ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false); - nhwc.initialize(fs); + nhwc.initialize(fs1); while(nchw.hasNext()){ assertTrue(nhwc.hasNext()); @@ -533,7 +535,7 @@ public class TestImageRecordReader { //Test record(URI, DataInputStream) - URI u = fs.locations()[0]; + URI u = fs0.locations()[0]; try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) { List l = nchw.record(u, dis);