#8201 ImageLoader/NativeImageLoader NHWC support (#432)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-04 17:17:57 +10:00 committed by GitHub
parent b70e02a915
commit 7651a486e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 263 additions and 36 deletions

View File

@ -16,6 +16,7 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
import org.datavec.image.transform.ImageTransform; import org.datavec.image.transform.ImageTransform;
@ -35,10 +36,9 @@ import java.util.Random;
/** /**
* Created by nyghtowl on 12/17/15. * Created by nyghtowl on 12/17/15.
*/ */
@Slf4j
public abstract class BaseImageLoader implements Serializable { public abstract class BaseImageLoader implements Serializable {
protected static final Logger log = LoggerFactory.getLogger(BaseImageLoader.class);
public enum MultiPageMode { public enum MultiPageMode {
MINIBATCH, FIRST //, CHANNELS, MINIBATCH, FIRST //, CHANNELS,
} }
@ -62,13 +62,37 @@ public abstract class BaseImageLoader implements Serializable {
public abstract INDArray asRowVector(InputStream inputStream) throws IOException; public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
/** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */
public abstract INDArray asMatrix(File f) throws IOException; public abstract INDArray asMatrix(File f) throws IOException;
/**
* Load an image from a file to an INDArray
* @param f File to load the image from
* @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return
* in NHWC/channels_last [1, height, width, channels] format
* @return Image file as an INDArray
*/
public abstract INDArray asMatrix(File f, boolean nchw) throws IOException;
public abstract INDArray asMatrix(InputStream inputStream) throws IOException; public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
/**
* Load an image file from an input stream to an INDArray
* @param inputStream Input stream to load the image from
* @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return
* in NHWC/channels_last [1, height, width, channels] format
* @return Image file stream as an INDArray
*/
public abstract INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException;
/** As per {@link #asMatrix(File)} but as an {@link Image}*/
public abstract Image asImageMatrix(File f) throws IOException; public abstract Image asImageMatrix(File f) throws IOException;
/** As per {@link #asMatrix(File, boolean)} but as an {@link Image}*/
public abstract Image asImageMatrix(File f, boolean nchw) throws IOException;
/** As per {@link #asMatrix(InputStream)} but as an {@link Image}*/
public abstract Image asImageMatrix(InputStream inputStream) throws IOException; public abstract Image asImageMatrix(InputStream inputStream) throws IOException;
/** As per {@link #asMatrix(InputStream, boolean)} but as an {@link Image}*/
public abstract Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException;
public static void downloadAndUntar(Map urlMap, File fullDir) { public static void downloadAndUntar(Map urlMap, File fullDir) {

View File

@ -16,6 +16,7 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
@ -47,6 +48,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*;
* There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example * There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example
* <a href="https://github.com/szagoruyko/cifar.torch">https://github.com/szagoruyko/cifar.torch</a> * <a href="https://github.com/szagoruyko/cifar.torch">https://github.com/szagoruyko/cifar.torch</a>
*/ */
@Slf4j
public class CifarLoader extends NativeImageLoader implements Serializable { public class CifarLoader extends NativeImageLoader implements Serializable {
public static final int NUM_TRAIN_IMAGES = 50000; public static final int NUM_TRAIN_IMAGES = 50000;
public static final int NUM_TEST_IMAGES = 10000; public static final int NUM_TEST_IMAGES = 10000;

View File

@ -249,7 +249,14 @@ public class ImageLoader extends BaseImageLoader {
* @throws IOException * @throws IOException
*/ */
public INDArray asMatrix(File f) throws IOException { public INDArray asMatrix(File f) throws IOException {
return NDArrayUtil.toNDArray(fromFile(f)); return asMatrix(f, true);
}
@Override
public INDArray asMatrix(File f, boolean nchw) throws IOException {
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
return asMatrix(is, nchw);
}
} }
/** /**
@ -259,34 +266,68 @@ public class ImageLoader extends BaseImageLoader {
* @return the input stream to convert * @return the input stream to convert
*/ */
public INDArray asMatrix(InputStream inputStream) throws IOException { public INDArray asMatrix(InputStream inputStream) throws IOException {
if (channels == 3) return asMatrix(inputStream, true);
return toBgr(inputStream); }
try {
BufferedImage image = ImageIO.read(inputStream); @Override
return asMatrix(image); public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
} catch (IOException e) { INDArray ret;
throw new IOException("Unable to load image", e); if (channels == 3) {
ret = toBgr(inputStream);
} else {
try {
BufferedImage image = ImageIO.read(inputStream);
ret = asMatrix(image);
} catch (IOException e) {
throw new IOException("Unable to load image", e);
}
} }
if(ret.rank() == 3){
ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2));
}
if(!nchw)
ret = ret.permute(0,2,3,1); //NCHW to NHWC
return ret;
} }
@Override @Override
public org.datavec.image.data.Image asImageMatrix(File f) throws IOException { public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
return asImageMatrix(f, true);
}
@Override
public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
return asImageMatrix(bis); return asImageMatrix(bis, nchw);
} }
} }
@Override @Override
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException { public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
if (channels == 3) return asImageMatrix(inputStream, true);
return toBgrImage(inputStream); }
try {
BufferedImage image = ImageIO.read(inputStream); @Override
INDArray asMatrix = asMatrix(image); public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
return new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth()); org.datavec.image.data.Image ret;
} catch (IOException e) { if (channels == 3) {
throw new IOException("Unable to load image", e); ret = toBgrImage(inputStream);
} else {
try {
BufferedImage image = ImageIO.read(inputStream);
INDArray asMatrix = asMatrix(image);
ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
} catch (IOException e) {
throw new IOException("Unable to load image", e);
}
} }
if(ret.getImage().rank() == 3){
INDArray a = ret.getImage();
ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2)));
}
if(!nchw)
ret.setImage(ret.getImage().permute(0,2,3,1)); //NCHW to NHWC
return ret;
} }
/** /**

View File

@ -17,6 +17,7 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.PathLabelGenerator; import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.io.labels.PatternPathLabelGenerator; import org.datavec.api.io.labels.PatternPathLabelGenerator;
@ -48,6 +49,7 @@ import java.util.Random;
* most images are in color, although a few are grayscale * most images are in color, although a few are grayscale
* *
*/ */
@Slf4j
public class LFWLoader extends BaseImageLoader implements Serializable { public class LFWLoader extends BaseImageLoader implements Serializable {
public final static int NUM_IMAGES = 13233; public final static int NUM_IMAGES = 13233;
@ -270,19 +272,39 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
/** Not implemented by this loader: every call throws {@link UnsupportedOperationException}. */
@Override
public INDArray asMatrix(File f, boolean nchw) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public INDArray asMatrix(InputStream inputStream) throws IOException { public INDArray asMatrix(InputStream inputStream) throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
/** Not implemented by this loader: every call throws {@link UnsupportedOperationException}. */
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public Image asImageMatrix(File f) throws IOException { public Image asImageMatrix(File f) throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
/** Not implemented by this loader: every call throws {@link UnsupportedOperationException}. */
@Override
public Image asImageMatrix(File f, boolean nchw) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public Image asImageMatrix(InputStream inputStream) throws IOException { public Image asImageMatrix(InputStream inputStream) throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
/** Not implemented by this loader: every call throws {@link UnsupportedOperationException}. */
@Override
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
throw new UnsupportedOperationException();
}
} }

View File

@ -248,17 +248,27 @@ public class NativeImageLoader extends BaseImageLoader {
@Override @Override
public INDArray asMatrix(File f) throws IOException { public INDArray asMatrix(File f) throws IOException {
return asMatrix(f, true);
}
@Override
public INDArray asMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
return asMatrix(bis); return asMatrix(bis, nchw);
} }
} }
@Override @Override
public INDArray asMatrix(InputStream is) throws IOException { public INDArray asMatrix(InputStream is) throws IOException {
Mat mat = streamToMat(is); return asMatrix(is, true);
}
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
Mat mat = streamToMat(inputStream);
INDArray a; INDArray a;
if (this.multiPageMode != null) { if (this.multiPageMode != null) {
a = asMatrix(mat.data(), mat.cols()); a = asMatrix(mat.data(), mat.cols());
}else{ }else{
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) { if (image == null || image.empty()) {
@ -272,7 +282,11 @@ public class NativeImageLoader extends BaseImageLoader {
a = asMatrix(image); a = asMatrix(image);
image.deallocate(); image.deallocate();
} }
return a; if(nchw) {
return a;
} else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC
}
} }
/** /**
@ -331,19 +345,29 @@ public class NativeImageLoader extends BaseImageLoader {
} }
public Image asImageMatrix(String filename) throws IOException { public Image asImageMatrix(String filename) throws IOException {
return asImageMatrix(filename); return asImageMatrix(new File(filename));
} }
@Override @Override
public Image asImageMatrix(File f) throws IOException { public Image asImageMatrix(File f) throws IOException {
return asImageMatrix(f, true);
}
@Override
public Image asImageMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
return asImageMatrix(bis); return asImageMatrix(bis, nchw);
} }
} }
@Override @Override
public Image asImageMatrix(InputStream is) throws IOException { public Image asImageMatrix(InputStream is) throws IOException {
Mat mat = streamToMat(is); return asImageMatrix(is, true);
}
@Override
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
Mat mat = streamToMat(inputStream);
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) { if (image == null || image.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols()); PIX pix = pixReadMem(mat.data(), mat.cols());
@ -354,6 +378,8 @@ public class NativeImageLoader extends BaseImageLoader {
pixDestroy(pix); pixDestroy(pix);
} }
INDArray a = asMatrix(image); INDArray a = asMatrix(image);
if(!nchw)
a = a.permute(0,2,3,1); //NCHW to NHWC
Image i = new Image(a, image.channels(), image.rows(), image.cols()); Image i = new Image(a, image.channels(), image.rows(), image.cols());
image.deallocate(); image.deallocate();

View File

@ -16,10 +16,16 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import org.datavec.image.data.Image;
import org.junit.Test; import org.junit.Test;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@ -208,4 +214,57 @@ public class TestImageLoader {
private BufferedImage makeRandomBufferedImage(boolean alpha) { private BufferedImage makeRandomBufferedImage(boolean alpha) {
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100); return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
} }
/**
 * Checks that the {@code nchw} flag of {@code ImageLoader} changes only the layout of the result.
 * For asMatrix and asImageMatrix, with both File and InputStream input, the overload without the flag
 * must equal the {@code nchw=true} result, and the {@code nchw=false} result must equal it once permuted
 * back with (0,3,1,2).
 */
@Test
public void testNCHW_NHWC() throws Exception {
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
ImageLoader il = new ImageLoader(32, 32, 3);
//asMatrix(File, boolean)
INDArray a_nchw = il.asMatrix(f);
INDArray a_nchw2 = il.asMatrix(f, true);
INDArray a_nhwc = il.asMatrix(f, false);
//Overload without the flag must match nchw=true
assertEquals(a_nchw, a_nchw2);
//NHWC to NCHW
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asMatrix(InputStream, boolean)
//A new stream is opened for every call (presumably because a load consumes the stream - confirm)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw = il.asMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw2 = il.asMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nhwc = il.asMatrix(is, false);
}
assertEquals(a_nchw, a_nchw2);
//NHWC to NCHW
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asImageMatrix(File, boolean)
Image i_nchw = il.asImageMatrix(f);
Image i_nchw2 = il.asImageMatrix(f, true);
Image i_nhwc = il.asImageMatrix(f, false);
//Only the wrapped arrays are compared, not the other Image fields
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
//asImageMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw = il.asImageMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw2 = il.asImageMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nhwc = il.asImageMatrix(is, false);
}
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
}
} }

View File

@ -24,20 +24,19 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.File; import java.io.*;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.IOException;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.util.Random; import java.util.Random;
@ -604,4 +603,56 @@ public class TestNativeImageLoader {
} }
} }
/**
 * Checks that the {@code nchw} flag of {@code NativeImageLoader} changes only the layout of the result.
 * For asMatrix and asImageMatrix, with both File and InputStream input, the overload without the flag
 * must equal the {@code nchw=true} result, and the {@code nchw=false} result must equal it once permuted
 * back with (0,3,1,2).
 */
@Test
public void testNCHW_NHWC() throws Exception {
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
NativeImageLoader il = new NativeImageLoader(32, 32, 3);
//asMatrix(File, boolean)
INDArray a_nchw = il.asMatrix(f);
INDArray a_nchw2 = il.asMatrix(f, true);
INDArray a_nhwc = il.asMatrix(f, false);
//Overload without the flag must match nchw=true
assertEquals(a_nchw, a_nchw2);
//NHWC to NCHW
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asMatrix(InputStream, boolean)
//A new stream is opened for every call (presumably because a load consumes the stream - confirm)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw = il.asMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw2 = il.asMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nhwc = il.asMatrix(is, false);
}
assertEquals(a_nchw, a_nchw2);
//NHWC to NCHW
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asImageMatrix(File, boolean)
Image i_nchw = il.asImageMatrix(f);
Image i_nchw2 = il.asImageMatrix(f, true);
Image i_nhwc = il.asImageMatrix(f, false);
//Only the wrapped arrays are compared, not the other Image fields
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
//asImageMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw = il.asImageMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw2 = il.asImageMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nhwc = il.asImageMatrix(is, false);
}
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
}
} }

View File

@ -474,16 +474,18 @@ public class TestImageRecordReader {
public void testNCHW_NCHW() throws Exception { public void testNCHW_NCHW() throws Exception {
//Idea: labels order should be consistent regardless of input file order //Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder(); File f0 = testDir.newFolder();
Resources.copyDirectory("datavec-data-image/testimages/", f0); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs = new FileSplit(f0, new Random(12345)); FileSplit fs0 = new FileSplit(f0, new Random(12345));
assertEquals(6, fs.locations().length); FileSplit fs1 = new FileSplit(f0, new Random(12345));
assertEquals(6, fs0.locations().length);
assertEquals(6, fs1.locations().length);
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true); ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
nchw.initialize(fs); nchw.initialize(fs0);
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false); ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
nhwc.initialize(fs); nhwc.initialize(fs1);
while(nchw.hasNext()){ while(nchw.hasNext()){
assertTrue(nhwc.hasNext()); assertTrue(nhwc.hasNext());
@ -533,7 +535,7 @@ public class TestImageRecordReader {
//Test record(URI, DataInputStream) //Test record(URI, DataInputStream)
URI u = fs.locations()[0]; URI u = fs0.locations()[0];
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) { try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
List<Writable> l = nchw.record(u, dis); List<Writable> l = nchw.record(u, dis);