Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
b70e02a915
commit
7651a486e1
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.datavec.image.data.Image;
|
import org.datavec.image.data.Image;
|
||||||
import org.datavec.image.transform.ImageTransform;
|
import org.datavec.image.transform.ImageTransform;
|
||||||
|
@ -35,10 +36,9 @@ import java.util.Random;
|
||||||
/**
|
/**
|
||||||
* Created by nyghtowl on 12/17/15.
|
* Created by nyghtowl on 12/17/15.
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public abstract class BaseImageLoader implements Serializable {
|
public abstract class BaseImageLoader implements Serializable {
|
||||||
|
|
||||||
protected static final Logger log = LoggerFactory.getLogger(BaseImageLoader.class);
|
|
||||||
|
|
||||||
public enum MultiPageMode {
|
public enum MultiPageMode {
|
||||||
MINIBATCH, FIRST //, CHANNELS,
|
MINIBATCH, FIRST //, CHANNELS,
|
||||||
}
|
}
|
||||||
|
@ -62,13 +62,37 @@ public abstract class BaseImageLoader implements Serializable {
|
||||||
|
|
||||||
public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
|
public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
|
||||||
|
|
||||||
|
/** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */
|
||||||
public abstract INDArray asMatrix(File f) throws IOException;
|
public abstract INDArray asMatrix(File f) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load an image from a file to an INDArray
|
||||||
|
* @param f File to load the image from
|
||||||
|
* @param nchw If true: return image in NCHW/channels_first [1, channels, height width] format; if false, return
|
||||||
|
* in NHWC/channels_last [1, height, width, channels] format
|
||||||
|
* @return Image file as as INDArray
|
||||||
|
*/
|
||||||
|
public abstract INDArray asMatrix(File f, boolean nchw) throws IOException;
|
||||||
|
|
||||||
public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
|
public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
|
||||||
|
/**
|
||||||
|
* Load an image file from an input stream to an INDArray
|
||||||
|
* @param inputStream Input stream to load the image from
|
||||||
|
* @param nchw If true: return image in NCHW/channels_first [1, channels, height width] format; if false, return
|
||||||
|
* in NHWC/channels_last [1, height, width, channels] format
|
||||||
|
* @return Image file stream as as INDArray
|
||||||
|
*/
|
||||||
|
public abstract INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException;
|
||||||
|
|
||||||
|
/** As per {@link #asMatrix(File)} but as an {@link Image}*/
|
||||||
public abstract Image asImageMatrix(File f) throws IOException;
|
public abstract Image asImageMatrix(File f) throws IOException;
|
||||||
|
/** As per {@link #asMatrix(File, boolean)} but as an {@link Image}*/
|
||||||
|
public abstract Image asImageMatrix(File f, boolean nchw) throws IOException;
|
||||||
|
|
||||||
|
/** As per {@link #asMatrix(InputStream)} but as an {@link Image}*/
|
||||||
public abstract Image asImageMatrix(InputStream inputStream) throws IOException;
|
public abstract Image asImageMatrix(InputStream inputStream) throws IOException;
|
||||||
|
/** As per {@link #asMatrix(InputStream, boolean)} but as an {@link Image}*/
|
||||||
|
public abstract Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException;
|
||||||
|
|
||||||
|
|
||||||
public static void downloadAndUntar(Map urlMap, File fullDir) {
|
public static void downloadAndUntar(Map urlMap, File fullDir) {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
|
@ -47,6 +48,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
* There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example
|
* There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example
|
||||||
* <a href="https://github.com/szagoruyko/cifar.torch">https://github.com/szagoruyko/cifar.torch</a>
|
* <a href="https://github.com/szagoruyko/cifar.torch">https://github.com/szagoruyko/cifar.torch</a>
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class CifarLoader extends NativeImageLoader implements Serializable {
|
public class CifarLoader extends NativeImageLoader implements Serializable {
|
||||||
public static final int NUM_TRAIN_IMAGES = 50000;
|
public static final int NUM_TRAIN_IMAGES = 50000;
|
||||||
public static final int NUM_TEST_IMAGES = 10000;
|
public static final int NUM_TEST_IMAGES = 10000;
|
||||||
|
|
|
@ -249,7 +249,14 @@ public class ImageLoader extends BaseImageLoader {
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
public INDArray asMatrix(File f) throws IOException {
|
public INDArray asMatrix(File f) throws IOException {
|
||||||
return NDArrayUtil.toNDArray(fromFile(f));
|
return asMatrix(f, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
return asMatrix(is, nchw);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -259,35 +266,69 @@ public class ImageLoader extends BaseImageLoader {
|
||||||
* @return the input stream to convert
|
* @return the input stream to convert
|
||||||
*/
|
*/
|
||||||
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
||||||
if (channels == 3)
|
return asMatrix(inputStream, true);
|
||||||
return toBgr(inputStream);
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
INDArray ret;
|
||||||
|
if (channels == 3) {
|
||||||
|
ret = toBgr(inputStream);
|
||||||
|
} else {
|
||||||
try {
|
try {
|
||||||
BufferedImage image = ImageIO.read(inputStream);
|
BufferedImage image = ImageIO.read(inputStream);
|
||||||
return asMatrix(image);
|
ret = asMatrix(image);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new IOException("Unable to load image", e);
|
throw new IOException("Unable to load image", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if(ret.rank() == 3){
|
||||||
|
ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2));
|
||||||
|
}
|
||||||
|
if(!nchw)
|
||||||
|
ret = ret.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
|
public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
|
||||||
|
return asImageMatrix(f, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||||
return asImageMatrix(bis);
|
return asImageMatrix(bis, nchw);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
|
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
|
||||||
if (channels == 3)
|
return asImageMatrix(inputStream, true);
|
||||||
return toBgrImage(inputStream);
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
org.datavec.image.data.Image ret;
|
||||||
|
if (channels == 3) {
|
||||||
|
ret = toBgrImage(inputStream);
|
||||||
|
} else {
|
||||||
try {
|
try {
|
||||||
BufferedImage image = ImageIO.read(inputStream);
|
BufferedImage image = ImageIO.read(inputStream);
|
||||||
INDArray asMatrix = asMatrix(image);
|
INDArray asMatrix = asMatrix(image);
|
||||||
return new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
|
ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new IOException("Unable to load image", e);
|
throw new IOException("Unable to load image", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if(ret.getImage().rank() == 3){
|
||||||
|
INDArray a = ret.getImage();
|
||||||
|
ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2)));
|
||||||
|
}
|
||||||
|
if(!nchw)
|
||||||
|
ret.setImage(ret.getImage().permute(0,2,3,1)); //NCHW to NHWC
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert an BufferedImage to a matrix
|
* Convert an BufferedImage to a matrix
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.datavec.api.io.filters.BalancedPathFilter;
|
import org.datavec.api.io.filters.BalancedPathFilter;
|
||||||
import org.datavec.api.io.labels.PathLabelGenerator;
|
import org.datavec.api.io.labels.PathLabelGenerator;
|
||||||
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
||||||
|
@ -48,6 +49,7 @@ import java.util.Random;
|
||||||
* most images are in color, although a few are grayscale
|
* most images are in color, although a few are grayscale
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class LFWLoader extends BaseImageLoader implements Serializable {
|
public class LFWLoader extends BaseImageLoader implements Serializable {
|
||||||
|
|
||||||
public final static int NUM_IMAGES = 13233;
|
public final static int NUM_IMAGES = 13233;
|
||||||
|
@ -270,19 +272,39 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(File f) throws IOException {
|
public Image asImageMatrix(File f) throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(InputStream inputStream) throws IOException {
|
public Image asImageMatrix(InputStream inputStream) throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -248,14 +248,24 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray asMatrix(File f) throws IOException {
|
public INDArray asMatrix(File f) throws IOException {
|
||||||
|
return asMatrix(f, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||||
return asMatrix(bis);
|
return asMatrix(bis, nchw);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray asMatrix(InputStream is) throws IOException {
|
public INDArray asMatrix(InputStream is) throws IOException {
|
||||||
Mat mat = streamToMat(is);
|
return asMatrix(is, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
Mat mat = streamToMat(inputStream);
|
||||||
INDArray a;
|
INDArray a;
|
||||||
if (this.multiPageMode != null) {
|
if (this.multiPageMode != null) {
|
||||||
a = asMatrix(mat.data(), mat.cols());
|
a = asMatrix(mat.data(), mat.cols());
|
||||||
|
@ -272,7 +282,11 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
a = asMatrix(image);
|
a = asMatrix(image);
|
||||||
image.deallocate();
|
image.deallocate();
|
||||||
}
|
}
|
||||||
|
if(nchw) {
|
||||||
return a;
|
return a;
|
||||||
|
} else {
|
||||||
|
return a.permute(0, 2, 3, 1); //NCHW to NHWC
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -331,19 +345,29 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Image asImageMatrix(String filename) throws IOException {
|
public Image asImageMatrix(String filename) throws IOException {
|
||||||
return asImageMatrix(filename);
|
return asImageMatrix(new File(filename));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(File f) throws IOException {
|
public Image asImageMatrix(File f) throws IOException {
|
||||||
|
return asImageMatrix(f, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||||
return asImageMatrix(bis);
|
return asImageMatrix(bis, nchw);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(InputStream is) throws IOException {
|
public Image asImageMatrix(InputStream is) throws IOException {
|
||||||
Mat mat = streamToMat(is);
|
return asImageMatrix(is, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
Mat mat = streamToMat(inputStream);
|
||||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||||
if (image == null || image.empty()) {
|
if (image == null || image.empty()) {
|
||||||
PIX pix = pixReadMem(mat.data(), mat.cols());
|
PIX pix = pixReadMem(mat.data(), mat.cols());
|
||||||
|
@ -354,6 +378,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
pixDestroy(pix);
|
pixDestroy(pix);
|
||||||
}
|
}
|
||||||
INDArray a = asMatrix(image);
|
INDArray a = asMatrix(image);
|
||||||
|
if(!nchw)
|
||||||
|
a = a.permute(0,2,3,1); //NCHW to NHWC
|
||||||
Image i = new Image(a, image.channels(), image.rows(), image.cols());
|
Image i = new Image(a, image.channels(), image.rows(), image.cols());
|
||||||
|
|
||||||
image.deallocate();
|
image.deallocate();
|
||||||
|
|
|
@ -16,10 +16,16 @@
|
||||||
|
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
|
import org.datavec.image.data.Image;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.BufferedInputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.InputStream;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
@ -208,4 +214,57 @@ public class TestImageLoader {
|
||||||
private BufferedImage makeRandomBufferedImage(boolean alpha) {
|
private BufferedImage makeRandomBufferedImage(boolean alpha) {
|
||||||
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
|
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNCHW_NHWC() throws Exception {
|
||||||
|
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
|
||||||
|
|
||||||
|
ImageLoader il = new ImageLoader(32, 32, 3);
|
||||||
|
|
||||||
|
//asMatrix(File, boolean)
|
||||||
|
INDArray a_nchw = il.asMatrix(f);
|
||||||
|
INDArray a_nchw2 = il.asMatrix(f, true);
|
||||||
|
INDArray a_nhwc = il.asMatrix(f, false);
|
||||||
|
|
||||||
|
assertEquals(a_nchw, a_nchw2);
|
||||||
|
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||||
|
|
||||||
|
|
||||||
|
//asMatrix(InputStream, boolean)
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nchw = il.asMatrix(is);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nchw2 = il.asMatrix(is, true);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nhwc = il.asMatrix(is, false);
|
||||||
|
}
|
||||||
|
assertEquals(a_nchw, a_nchw2);
|
||||||
|
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||||
|
|
||||||
|
|
||||||
|
//asImageMatrix(File, boolean)
|
||||||
|
Image i_nchw = il.asImageMatrix(f);
|
||||||
|
Image i_nchw2 = il.asImageMatrix(f, true);
|
||||||
|
Image i_nhwc = il.asImageMatrix(f, false);
|
||||||
|
|
||||||
|
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||||
|
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
|
||||||
|
|
||||||
|
//asImageMatrix(InputStream, boolean)
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nchw = il.asImageMatrix(is);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nchw2 = il.asImageMatrix(is, true);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nhwc = il.asImageMatrix(is, false);
|
||||||
|
}
|
||||||
|
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||||
|
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,20 +24,19 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
|
||||||
import org.bytedeco.javacv.Frame;
|
import org.bytedeco.javacv.Frame;
|
||||||
import org.bytedeco.javacv.Java2DFrameConverter;
|
import org.bytedeco.javacv.Java2DFrameConverter;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
|
import org.datavec.image.data.Image;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
import java.io.File;
|
import java.io.*;
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
@ -604,4 +603,56 @@ public class TestNativeImageLoader {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNCHW_NHWC() throws Exception {
|
||||||
|
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
|
||||||
|
|
||||||
|
NativeImageLoader il = new NativeImageLoader(32, 32, 3);
|
||||||
|
|
||||||
|
//asMatrix(File, boolean)
|
||||||
|
INDArray a_nchw = il.asMatrix(f);
|
||||||
|
INDArray a_nchw2 = il.asMatrix(f, true);
|
||||||
|
INDArray a_nhwc = il.asMatrix(f, false);
|
||||||
|
|
||||||
|
assertEquals(a_nchw, a_nchw2);
|
||||||
|
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||||
|
|
||||||
|
|
||||||
|
//asMatrix(InputStream, boolean)
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nchw = il.asMatrix(is);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nchw2 = il.asMatrix(is, true);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nhwc = il.asMatrix(is, false);
|
||||||
|
}
|
||||||
|
assertEquals(a_nchw, a_nchw2);
|
||||||
|
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||||
|
|
||||||
|
|
||||||
|
//asImageMatrix(File, boolean)
|
||||||
|
Image i_nchw = il.asImageMatrix(f);
|
||||||
|
Image i_nchw2 = il.asImageMatrix(f, true);
|
||||||
|
Image i_nhwc = il.asImageMatrix(f, false);
|
||||||
|
|
||||||
|
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||||
|
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
|
||||||
|
|
||||||
|
//asImageMatrix(InputStream, boolean)
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nchw = il.asImageMatrix(is);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nchw2 = il.asImageMatrix(is, true);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nhwc = il.asImageMatrix(is, false);
|
||||||
|
}
|
||||||
|
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||||
|
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -474,16 +474,18 @@ public class TestImageRecordReader {
|
||||||
public void testNCHW_NCHW() throws Exception {
|
public void testNCHW_NCHW() throws Exception {
|
||||||
//Idea: labels order should be consistent regardless of input file order
|
//Idea: labels order should be consistent regardless of input file order
|
||||||
File f0 = testDir.newFolder();
|
File f0 = testDir.newFolder();
|
||||||
Resources.copyDirectory("datavec-data-image/testimages/", f0);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||||
|
|
||||||
FileSplit fs = new FileSplit(f0, new Random(12345));
|
FileSplit fs0 = new FileSplit(f0, new Random(12345));
|
||||||
assertEquals(6, fs.locations().length);
|
FileSplit fs1 = new FileSplit(f0, new Random(12345));
|
||||||
|
assertEquals(6, fs0.locations().length);
|
||||||
|
assertEquals(6, fs1.locations().length);
|
||||||
|
|
||||||
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
|
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
|
||||||
nchw.initialize(fs);
|
nchw.initialize(fs0);
|
||||||
|
|
||||||
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
|
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
|
||||||
nhwc.initialize(fs);
|
nhwc.initialize(fs1);
|
||||||
|
|
||||||
while(nchw.hasNext()){
|
while(nchw.hasNext()){
|
||||||
assertTrue(nhwc.hasNext());
|
assertTrue(nhwc.hasNext());
|
||||||
|
@ -533,7 +535,7 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
//Test record(URI, DataInputStream)
|
//Test record(URI, DataInputStream)
|
||||||
|
|
||||||
URI u = fs.locations()[0];
|
URI u = fs0.locations()[0];
|
||||||
|
|
||||||
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
|
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
|
||||||
List<Writable> l = nchw.record(u, dis);
|
List<Writable> l = nchw.record(u, dis);
|
||||||
|
|
Loading…
Reference in New Issue