Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
b70e02a915
commit
7651a486e1
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.datavec.image.loader;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.datavec.image.data.Image;
|
||||
import org.datavec.image.transform.ImageTransform;
|
||||
|
@ -35,10 +36,9 @@ import java.util.Random;
|
|||
/**
|
||||
* Created by nyghtowl on 12/17/15.
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseImageLoader implements Serializable {
|
||||
|
||||
protected static final Logger log = LoggerFactory.getLogger(BaseImageLoader.class);
|
||||
|
||||
public enum MultiPageMode {
|
||||
MINIBATCH, FIRST //, CHANNELS,
|
||||
}
|
||||
|
@ -62,13 +62,37 @@ public abstract class BaseImageLoader implements Serializable {
|
|||
|
||||
public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
|
||||
|
||||
/** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */
|
||||
public abstract INDArray asMatrix(File f) throws IOException;
|
||||
|
||||
/**
|
||||
* Load an image from a file to an INDArray
|
||||
* @param f File to load the image from
|
||||
* @param nchw If true: return image in NCHW/channels_first [1, channels, height width] format; if false, return
|
||||
* in NHWC/channels_last [1, height, width, channels] format
|
||||
* @return Image file as as INDArray
|
||||
*/
|
||||
public abstract INDArray asMatrix(File f, boolean nchw) throws IOException;
|
||||
|
||||
public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
|
||||
/**
|
||||
* Load an image file from an input stream to an INDArray
|
||||
* @param inputStream Input stream to load the image from
|
||||
* @param nchw If true: return image in NCHW/channels_first [1, channels, height width] format; if false, return
|
||||
* in NHWC/channels_last [1, height, width, channels] format
|
||||
* @return Image file stream as as INDArray
|
||||
*/
|
||||
public abstract INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException;
|
||||
|
||||
/** As per {@link #asMatrix(File)} but as an {@link Image}*/
|
||||
public abstract Image asImageMatrix(File f) throws IOException;
|
||||
/** As per {@link #asMatrix(File, boolean)} but as an {@link Image}*/
|
||||
public abstract Image asImageMatrix(File f, boolean nchw) throws IOException;
|
||||
|
||||
/** As per {@link #asMatrix(InputStream)} but as an {@link Image}*/
|
||||
public abstract Image asImageMatrix(InputStream inputStream) throws IOException;
|
||||
/** As per {@link #asMatrix(InputStream, boolean)} but as an {@link Image}*/
|
||||
public abstract Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException;
|
||||
|
||||
|
||||
public static void downloadAndUntar(Map urlMap, File fullDir) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.datavec.image.loader;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
|
@ -47,6 +48,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
|||
* There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example
|
||||
* <a href="https://github.com/szagoruyko/cifar.torch">https://github.com/szagoruyko/cifar.torch</a>
|
||||
*/
|
||||
@Slf4j
|
||||
public class CifarLoader extends NativeImageLoader implements Serializable {
|
||||
public static final int NUM_TRAIN_IMAGES = 50000;
|
||||
public static final int NUM_TEST_IMAGES = 10000;
|
||||
|
|
|
@ -249,7 +249,14 @@ public class ImageLoader extends BaseImageLoader {
|
|||
* @throws IOException
|
||||
*/
|
||||
public INDArray asMatrix(File f) throws IOException {
|
||||
return NDArrayUtil.toNDArray(fromFile(f));
|
||||
return asMatrix(f, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
return asMatrix(is, nchw);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -259,34 +266,68 @@ public class ImageLoader extends BaseImageLoader {
|
|||
* @return the input stream to convert
|
||||
*/
|
||||
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
||||
if (channels == 3)
|
||||
return toBgr(inputStream);
|
||||
try {
|
||||
BufferedImage image = ImageIO.read(inputStream);
|
||||
return asMatrix(image);
|
||||
} catch (IOException e) {
|
||||
throw new IOException("Unable to load image", e);
|
||||
return asMatrix(inputStream, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
INDArray ret;
|
||||
if (channels == 3) {
|
||||
ret = toBgr(inputStream);
|
||||
} else {
|
||||
try {
|
||||
BufferedImage image = ImageIO.read(inputStream);
|
||||
ret = asMatrix(image);
|
||||
} catch (IOException e) {
|
||||
throw new IOException("Unable to load image", e);
|
||||
}
|
||||
}
|
||||
if(ret.rank() == 3){
|
||||
ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2));
|
||||
}
|
||||
if(!nchw)
|
||||
ret = ret.permute(0,2,3,1); //NCHW to NHWC
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
|
||||
return asImageMatrix(f, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||
return asImageMatrix(bis);
|
||||
return asImageMatrix(bis, nchw);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
|
||||
if (channels == 3)
|
||||
return toBgrImage(inputStream);
|
||||
try {
|
||||
BufferedImage image = ImageIO.read(inputStream);
|
||||
INDArray asMatrix = asMatrix(image);
|
||||
return new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
|
||||
} catch (IOException e) {
|
||||
throw new IOException("Unable to load image", e);
|
||||
return asImageMatrix(inputStream, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
org.datavec.image.data.Image ret;
|
||||
if (channels == 3) {
|
||||
ret = toBgrImage(inputStream);
|
||||
} else {
|
||||
try {
|
||||
BufferedImage image = ImageIO.read(inputStream);
|
||||
INDArray asMatrix = asMatrix(image);
|
||||
ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
|
||||
} catch (IOException e) {
|
||||
throw new IOException("Unable to load image", e);
|
||||
}
|
||||
}
|
||||
if(ret.getImage().rank() == 3){
|
||||
INDArray a = ret.getImage();
|
||||
ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2)));
|
||||
}
|
||||
if(!nchw)
|
||||
ret.setImage(ret.getImage().permute(0,2,3,1)); //NCHW to NHWC
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.datavec.image.loader;
|
||||
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.datavec.api.io.filters.BalancedPathFilter;
|
||||
import org.datavec.api.io.labels.PathLabelGenerator;
|
||||
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
||||
|
@ -48,6 +49,7 @@ import java.util.Random;
|
|||
* most images are in color, although a few are grayscale
|
||||
*
|
||||
*/
|
||||
@Slf4j
|
||||
public class LFWLoader extends BaseImageLoader implements Serializable {
|
||||
|
||||
public final static int NUM_IMAGES = 13233;
|
||||
|
@ -270,19 +272,39 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(File f) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(InputStream inputStream) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -248,17 +248,27 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
|
||||
@Override
|
||||
public INDArray asMatrix(File f) throws IOException {
|
||||
return asMatrix(f, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||
return asMatrix(bis);
|
||||
return asMatrix(bis, nchw);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream is) throws IOException {
|
||||
Mat mat = streamToMat(is);
|
||||
return asMatrix(is, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
Mat mat = streamToMat(inputStream);
|
||||
INDArray a;
|
||||
if (this.multiPageMode != null) {
|
||||
a = asMatrix(mat.data(), mat.cols());
|
||||
a = asMatrix(mat.data(), mat.cols());
|
||||
}else{
|
||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||
if (image == null || image.empty()) {
|
||||
|
@ -272,7 +282,11 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
a = asMatrix(image);
|
||||
image.deallocate();
|
||||
}
|
||||
return a;
|
||||
if(nchw) {
|
||||
return a;
|
||||
} else {
|
||||
return a.permute(0, 2, 3, 1); //NCHW to NHWC
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -331,19 +345,29 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
}
|
||||
|
||||
public Image asImageMatrix(String filename) throws IOException {
|
||||
return asImageMatrix(filename);
|
||||
return asImageMatrix(new File(filename));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(File f) throws IOException {
|
||||
return asImageMatrix(f, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||
return asImageMatrix(bis);
|
||||
return asImageMatrix(bis, nchw);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(InputStream is) throws IOException {
|
||||
Mat mat = streamToMat(is);
|
||||
return asImageMatrix(is, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
Mat mat = streamToMat(inputStream);
|
||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||
if (image == null || image.empty()) {
|
||||
PIX pix = pixReadMem(mat.data(), mat.cols());
|
||||
|
@ -354,6 +378,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
pixDestroy(pix);
|
||||
}
|
||||
INDArray a = asMatrix(image);
|
||||
if(!nchw)
|
||||
a = a.permute(0,2,3,1); //NCHW to NHWC
|
||||
Image i = new Image(a, image.channels(), image.rows(), image.cols());
|
||||
|
||||
image.deallocate();
|
||||
|
|
|
@ -16,10 +16,16 @@
|
|||
|
||||
package org.datavec.image.loader;
|
||||
|
||||
import org.datavec.image.data.Image;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
@ -208,4 +214,57 @@ public class TestImageLoader {
|
|||
private BufferedImage makeRandomBufferedImage(boolean alpha) {
|
||||
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testNCHW_NHWC() throws Exception {
|
||||
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
|
||||
|
||||
ImageLoader il = new ImageLoader(32, 32, 3);
|
||||
|
||||
//asMatrix(File, boolean)
|
||||
INDArray a_nchw = il.asMatrix(f);
|
||||
INDArray a_nchw2 = il.asMatrix(f, true);
|
||||
INDArray a_nhwc = il.asMatrix(f, false);
|
||||
|
||||
assertEquals(a_nchw, a_nchw2);
|
||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||
|
||||
|
||||
//asMatrix(InputStream, boolean)
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nchw = il.asMatrix(is);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nchw2 = il.asMatrix(is, true);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nhwc = il.asMatrix(is, false);
|
||||
}
|
||||
assertEquals(a_nchw, a_nchw2);
|
||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||
|
||||
|
||||
//asImageMatrix(File, boolean)
|
||||
Image i_nchw = il.asImageMatrix(f);
|
||||
Image i_nchw2 = il.asImageMatrix(f, true);
|
||||
Image i_nhwc = il.asImageMatrix(f, false);
|
||||
|
||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||
|
||||
|
||||
//asImageMatrix(InputStream, boolean)
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nchw = il.asImageMatrix(is);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nchw2 = il.asImageMatrix(is, true);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nhwc = il.asImageMatrix(is, false);
|
||||
}
|
||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,20 +24,19 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
|
|||
import org.bytedeco.javacv.Frame;
|
||||
import org.bytedeco.javacv.Java2DFrameConverter;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.datavec.image.data.Image;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.Random;
|
||||
|
||||
|
@ -604,4 +603,56 @@ public class TestNativeImageLoader {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNCHW_NHWC() throws Exception {
|
||||
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
|
||||
|
||||
NativeImageLoader il = new NativeImageLoader(32, 32, 3);
|
||||
|
||||
//asMatrix(File, boolean)
|
||||
INDArray a_nchw = il.asMatrix(f);
|
||||
INDArray a_nchw2 = il.asMatrix(f, true);
|
||||
INDArray a_nhwc = il.asMatrix(f, false);
|
||||
|
||||
assertEquals(a_nchw, a_nchw2);
|
||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||
|
||||
|
||||
//asMatrix(InputStream, boolean)
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nchw = il.asMatrix(is);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nchw2 = il.asMatrix(is, true);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nhwc = il.asMatrix(is, false);
|
||||
}
|
||||
assertEquals(a_nchw, a_nchw2);
|
||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||
|
||||
|
||||
//asImageMatrix(File, boolean)
|
||||
Image i_nchw = il.asImageMatrix(f);
|
||||
Image i_nchw2 = il.asImageMatrix(f, true);
|
||||
Image i_nhwc = il.asImageMatrix(f, false);
|
||||
|
||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||
|
||||
|
||||
//asImageMatrix(InputStream, boolean)
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nchw = il.asImageMatrix(is);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nchw2 = il.asImageMatrix(is, true);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nhwc = il.asImageMatrix(is, false);
|
||||
}
|
||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -474,16 +474,18 @@ public class TestImageRecordReader {
|
|||
public void testNCHW_NCHW() throws Exception {
|
||||
//Idea: labels order should be consistent regardless of input file order
|
||||
File f0 = testDir.newFolder();
|
||||
Resources.copyDirectory("datavec-data-image/testimages/", f0);
|
||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||
|
||||
FileSplit fs = new FileSplit(f0, new Random(12345));
|
||||
assertEquals(6, fs.locations().length);
|
||||
FileSplit fs0 = new FileSplit(f0, new Random(12345));
|
||||
FileSplit fs1 = new FileSplit(f0, new Random(12345));
|
||||
assertEquals(6, fs0.locations().length);
|
||||
assertEquals(6, fs1.locations().length);
|
||||
|
||||
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
|
||||
nchw.initialize(fs);
|
||||
nchw.initialize(fs0);
|
||||
|
||||
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
|
||||
nhwc.initialize(fs);
|
||||
nhwc.initialize(fs1);
|
||||
|
||||
while(nchw.hasNext()){
|
||||
assertTrue(nhwc.hasNext());
|
||||
|
@ -533,7 +535,7 @@ public class TestImageRecordReader {
|
|||
|
||||
//Test record(URI, DataInputStream)
|
||||
|
||||
URI u = fs.locations()[0];
|
||||
URI u = fs0.locations()[0];
|
||||
|
||||
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
|
||||
List<Writable> l = nchw.record(u, dis);
|
||||
|
|
Loading…
Reference in New Issue