datavec-data-image test fixes

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2022-10-10 17:01:23 +02:00
parent 205252e5a9
commit a4bf1c3e62
22 changed files with 321 additions and 254 deletions

View File

@ -79,4 +79,13 @@ public class CollectionInputSplit extends BaseInputSplit {
return true; return true;
} }
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
} }

View File

@ -31,6 +31,7 @@ import org.nd4j.common.util.MathUtils;
import java.io.*; import java.io.*;
import java.net.URI; import java.net.URI;
import java.nio.file.Path;
import java.util.*; import java.util.*;
public class FileSplit extends BaseInputSplit { public class FileSplit extends BaseInputSplit {
@ -59,6 +60,10 @@ public class FileSplit extends BaseInputSplit {
this(rootDir, null, true, null, true); this(rootDir, null, true, null, true);
} }
public FileSplit(Path rootDir) {
this(rootDir.toFile(), null, true, null, true);
}
public FileSplit(File rootDir, Random rng) { public FileSplit(File rootDir, Random rng) {
this(rootDir, null, true, rng, true); this(rootDir, null, true, rng, true);
} }
@ -214,6 +219,14 @@ public class FileSplit extends BaseInputSplit {
return true; return true;
} }
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
public File getRootDir() { public File getRootDir() {
return rootDir; return rootDir;

View File

@ -133,4 +133,9 @@ public interface InputSplit {
* may throw an exception * may throw an exception
*/ */
boolean resetSupported(); boolean resetSupported();
/**
* Close input/ output streams if any
*/
void close();
} }

View File

@ -21,6 +21,7 @@
package org.datavec.api.split; package org.datavec.api.split;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.net.URI; import java.net.URI;
@ -149,6 +150,18 @@ public class InputStreamInputSplit implements InputSplit {
return location != null && location.length > 0; return location != null && location.length > 0;
} }
/**
* Close input/ output streams if any
*/
@Override
public void close() {
try {
is.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public InputStream getIs() { public InputStream getIs() {
return is; return is;

View File

@ -20,6 +20,7 @@
package org.datavec.api.split; package org.datavec.api.split;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.net.URI; import java.net.URI;
@ -124,4 +125,12 @@ public class ListStringSplit implements InputSplit {
public List<List<String>> getData() { public List<List<String>> getData() {
return data; return data;
} }
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
} }

View File

@ -153,6 +153,14 @@ public class NumberedFileInputSplit implements InputSplit {
return true; return true;
} }
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
private class NumberedFileIterator implements Iterator<String> { private class NumberedFileIterator implements Iterator<String> {
@ -179,5 +187,7 @@ public class NumberedFileInputSplit implements InputSplit {
public void remove() { public void remove() {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
} }
} }

View File

@ -23,6 +23,7 @@ package org.datavec.api.split;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.net.URI; import java.net.URI;
@ -115,5 +116,17 @@ public class OutputStreamInputSplit implements InputSplit {
return false; return false;
} }
/**
* Close input/ output streams if any
*/
@Override
public void close() {
try {
outputStream.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
} }

View File

@ -143,4 +143,12 @@ public class StreamInputSplit implements InputSplit {
public boolean resetSupported() { public boolean resetSupported() {
return true; return true;
} }
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
} }

View File

@ -107,6 +107,13 @@ public class StringSplit implements InputSplit {
return true; return true;
} }
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
public String getData() { public String getData() {

View File

@ -111,6 +111,15 @@ public class TransformSplit extends BaseInputSplit {
return true; return true;
} }
/**
* Close input/ output streams if any
*/
@Override
public void close() {
sourceSplit.close();
}
public interface URITransform { public interface URITransform {
URI apply(URI uri) throws URISyntaxException; URI apply(URI uri) throws URISyntaxException;
} }

View File

@ -99,8 +99,6 @@ public class ExcelRecordWriter extends FileRecordWriter {
partitioner.init(inputSplit); partitioner.init(inputSplit);
out = new DataOutputStream(partitioner.currentOutputStream()); out = new DataOutputStream(partitioner.currentOutputStream());
initPoi(); initPoi();
} }
private void initPoi() { private void initPoi() {

View File

@ -60,9 +60,6 @@ public class AndroidNativeImageLoader extends NativeImageLoader {
} }
public INDArray asMatrix(Bitmap image) throws IOException { public INDArray asMatrix(Bitmap image) throws IOException {
if (converter == null) {
converter = new OpenCVFrameConverter.ToMat();
}
return asMatrix(converter.convert(converter2.convert(image))); return asMatrix(converter.convert(converter2.convert(image)));
} }

View File

@ -20,6 +20,7 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
@ -43,7 +44,10 @@ public abstract class BaseImageLoader implements Serializable {
} }
public static final File BASE_DIR = new File(System.getProperty("user.home")); public static final File BASE_DIR = new File(System.getProperty("user.home"));
@Getter
public static final String[] ALLOWED_FORMATS = {"tif", "jpg", "png", "jpeg", "bmp", "JPEG", "JPG", "TIF", "PNG"}; public static final String[] ALLOWED_FORMATS = {"tif", "jpg", "png", "jpeg", "bmp", "JPEG", "JPG", "TIF", "PNG"};
protected Random rng = new Random(System.currentTimeMillis()); protected Random rng = new Random(System.currentTimeMillis());
protected long height = -1; protected long height = -1;
@ -53,16 +57,17 @@ public abstract class BaseImageLoader implements Serializable {
protected ImageTransform imageTransform = null; protected ImageTransform imageTransform = null;
protected MultiPageMode multiPageMode = null; protected MultiPageMode multiPageMode = null;
public String[] getAllowedFormats() {
return ALLOWED_FORMATS;
}
public abstract INDArray asRowVector(File f) throws IOException; public abstract INDArray asRowVector(File f) throws IOException;
public abstract INDArray asRowVector(InputStream inputStream) throws IOException; public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
/** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */ /** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format.
public abstract INDArray asMatrix(File f) throws IOException; * Essentially calls asMatrix(File f, true)
*
**/
public INDArray asMatrix(File f) throws IOException {
return asMatrix( f, true);
}
/** /**
* Load an image from a file to an INDArray * Load an image from a file to an INDArray
@ -73,7 +78,15 @@ public abstract class BaseImageLoader implements Serializable {
*/ */
public abstract INDArray asMatrix(File f, boolean nchw) throws IOException; public abstract INDArray asMatrix(File f, boolean nchw) throws IOException;
public abstract INDArray asMatrix(InputStream inputStream) throws IOException; /**
* Load an image file from an input stream to an INDArray. Essentially calls asMatrix(inputStream, true)
* {@link #asMatrix(InputStream, boolean)} asMatrix
* @param inputStream Input stream to load the image from
* @return Image file stream as as INDArray NCHW/channels_first [1, channels, height width] format
*/
public INDArray asMatrix(InputStream inputStream) throws IOException {
return asMatrix(inputStream, true);
}
/** /**
* Load an image file from an input stream to an INDArray * Load an image file from an input stream to an INDArray
* @param inputStream Input stream to load the image from * @param inputStream Input stream to load the image from

View File

@ -257,16 +257,6 @@ public class ImageLoader extends BaseImageLoader {
} }
} }
/**
* Convert an input stream to a matrix
*
* @param inputStream the input stream to convert
* @return the input stream to convert
*/
public INDArray asMatrix(InputStream inputStream) throws IOException {
return asMatrix(inputStream, true);
}
@Override @Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
INDArray ret; INDArray ret;

View File

@ -85,9 +85,6 @@ public class Java2DNativeImageLoader extends NativeImageLoader {
* @throws IOException * @throws IOException
*/ */
public INDArray asMatrix(BufferedImage image, boolean flipChannels) throws IOException { public INDArray asMatrix(BufferedImage image, boolean flipChannels) throws IOException {
if (converter == null) {
converter = new OpenCVFrameConverter.ToMat();
}
return asMatrix(converter.convert(converter2.getFrame(image, 1.0, flipChannels))); return asMatrix(converter.convert(converter2.getFrame(image, 1.0, flipChannels)));
} }

View File

@ -20,6 +20,8 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*; import org.bytedeco.javacpp.indexer.*;
@ -38,6 +40,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.util.ArrayUtil; import org.nd4j.common.util.ArrayUtil;
import java.io.*; import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import org.bytedeco.leptonica.*; import org.bytedeco.leptonica.*;
@ -49,8 +52,9 @@ import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*; import static org.bytedeco.opencv.global.opencv_imgproc.*;
/** /**
* Uses JavaCV to load images. Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp * Uses JavaCV (that also wraps OpenCV) to load images.
* * Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp
* JavaCV supports a wider range of image formats compared to the {@link ImageLoader} variant.
* @author saudet * @author saudet
*/ */
public class NativeImageLoader extends BaseImageLoader { public class NativeImageLoader extends BaseImageLoader {
@ -58,12 +62,14 @@ public class NativeImageLoader extends BaseImageLoader {
private byte[] buffer = null; private byte[] buffer = null;
private Mat bufferMat = null; private Mat bufferMat = null;
@Getter
public static final String[] ALLOWED_FORMATS = {"bmp", "gif", "jpg", "jpeg", "jp2", "pbm", "pgm", "ppm", "pnm", public static final String[] ALLOWED_FORMATS = {"bmp", "gif", "jpg", "jpeg", "jp2", "pbm", "pgm", "ppm", "pnm",
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM", "png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
"PNG", "TIF", "TIFF", "EXR", "WEBP"}; "PNG", "TIF", "TIFF", "EXR", "WEBP"};
protected OpenCVFrameConverter.ToMat converter; protected final OpenCVFrameConverter.ToMat converter;
//Todo: Should be final, but TestNativeImageLoader uses this to simulate for Android
boolean direct = !Loader.getPlatform().startsWith("android"); boolean direct = !Loader.getPlatform().startsWith("android");
/** /**
@ -144,17 +150,9 @@ public class NativeImageLoader extends BaseImageLoader {
this.imageTransform = other.imageTransform; this.imageTransform = other.imageTransform;
} }
@Override
public String[] getAllowedFormats() {
return ALLOWED_FORMATS;
}
public INDArray asRowVector(String filename) throws IOException {
return asRowVector(new File(filename));
}
/** /**
* Convert a file to a row vector * Convert a file to a row vector by loading it into an {@link INDArray} and then
* calling flattening {@link INDArray#ravel()}.
* *
* @param f the image to convert * @param f the image to convert
* @return the flattened image * @return the flattened image
@ -164,7 +162,14 @@ public class NativeImageLoader extends BaseImageLoader {
public INDArray asRowVector(File f) throws IOException { public INDArray asRowVector(File f) throws IOException {
return asMatrix(f).ravel(); return asMatrix(f).ravel();
} }
/**
* Convert an input stream containing an image to a row vector by loading it into an {@link INDArray} and then
* calling flattening {@link INDArray#ravel()}.
*
* @param is the image input stream to convert
* @return the flattened image
* @throws IOException
*/
@Override @Override
public INDArray asRowVector(InputStream is) throws IOException { public INDArray asRowVector(InputStream is) throws IOException {
return asMatrix(is).ravel(); return asMatrix(is).ravel();
@ -192,7 +197,15 @@ public class NativeImageLoader extends BaseImageLoader {
return arr.reshape('c', 1, arr.length()); return arr.reshape('c', 1, arr.length());
} }
static Mat convert(PIX pix) { /**
* Helper method to convert a {@see http://leptonica.org Leptonica PIX} into an OpenCV Matrix.
* Leptonica is a pedagogically-oriented open source library containing software that is
* broadly useful for image processing and image analysis applications.
* @param pix the leptonica image format.
* @return OpenCV Matrix
*/
static Mat convert(@NonNull PIX pix) {
PIX tempPix = null; PIX tempPix = null;
int dtype = -1; int dtype = -1;
int height = pix.h(); int height = pix.h();
@ -245,54 +258,22 @@ public class NativeImageLoader extends BaseImageLoader {
return mat2; return mat2;
} }
public INDArray asMatrix(String filename) throws IOException {
return asMatrix(new File(filename));
}
@Override @Override
public INDArray asMatrix(File f) throws IOException { public INDArray asMatrix(@NonNull File f) throws IOException {
return asMatrix(f, true); return asMatrix(f, true);
} }
@Override @Override
public INDArray asMatrix(File f, boolean nchw) throws IOException { public INDArray asMatrix(@NonNull File f, boolean nchw) throws IOException {
Mat mat = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR ); return asMatrix(new FileInputStream(f), nchw);
INDArray a;
if (this.multiPageMode != null) {
a = asMatrix(mat.data(), mat.cols());
}else{
// Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (mat == null || mat.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols());
if (pix == null) {
throw new IOException("Could not decode image from input stream");
}
mat = convert(pix);
pixDestroy(pix);
}
a = asMatrix(mat);
mat.deallocate();
}
if(nchw) {
return a;
} else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC
}
} }
@Override @Override
public INDArray asMatrix(InputStream is) throws IOException { public INDArray asMatrix(@NonNull InputStream inputStream, boolean nchw) throws IOException {
return asMatrix(is, true);
}
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
throw new RuntimeException("not implemented");
/*
Mat mat = streamToMat(inputStream); Mat mat = streamToMat(inputStream);
INDArray a; INDArray a;
if (this.multiPageMode != null) { if (this.multiPageMode != null) {
a = asMatrix(mat.data(), mat.cols()); a = asMatrix(mat.data(), mat.arraySize());
} else { } else {
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) { if (image == null || image.empty()) {
@ -311,8 +292,6 @@ public class NativeImageLoader extends BaseImageLoader {
} else { } else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC return a.permute(0, 2, 3, 1); //NCHW to NHWC
} }
*/
} }
/** /**
@ -321,53 +300,13 @@ public class NativeImageLoader extends BaseImageLoader {
* @return Mat with the buffer data as a row vector * @return Mat with the buffer data as a row vector
* @throws IOException * @throws IOException
*/ */
private Mat streamToMat(InputStream is) throws IOException { private Mat streamToMat(@NonNull InputStream is) throws IOException {
if(buffer == null){ byte[] bytearray = IOUtils.toByteArray(is); //Todo: can be very large
buffer = IOUtils.toByteArray(is); if(bytearray == null || bytearray.length <= 0 ) {
if(buffer.length == 0){
throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
} }
bufferMat = new Mat(buffer); Mat outputMat = new Mat(bytearray);
return bufferMat; return outputMat;
} else {
int numReadTotal = is.read(buffer);
//Need to know if all data has been read.
//(a) if numRead < buffer.length - got everything
//(b) if numRead >= buffer.length: we MIGHT have got everything (exact right size buffer) OR we need more data
if(numReadTotal <= 0){
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
}
if(numReadTotal < buffer.length){
bufferMat.data().put(buffer, 0, numReadTotal);
bufferMat.cols(numReadTotal);
return bufferMat;
}
//Buffer is full; reallocate and keep reading
int numReadCurrent = numReadTotal;
while(numReadCurrent != -1){
byte[] oldBuffer = buffer;
if(oldBuffer.length == Integer.MAX_VALUE){
throw new IllegalStateException("Cannot read more than Integer.MAX_VALUE bytes");
}
//Double buffer, but allocate at least 1MB more
long increase = Math.max(buffer.length, MIN_BUFFER_STEP_SIZE);
int newBufferLength = (int)Math.min(Integer.MAX_VALUE, buffer.length + increase);
buffer = new byte[newBufferLength];
System.arraycopy(oldBuffer, 0, buffer, 0, oldBuffer.length);
numReadCurrent = is.read(buffer, oldBuffer.length, buffer.length - oldBuffer.length);
if(numReadCurrent > 0){
numReadTotal += numReadCurrent;
}
}
bufferMat = new Mat(buffer);
return bufferMat;
}
} }
public Image asImageMatrix(String filename) throws IOException { public Image asImageMatrix(String filename) throws IOException {
@ -624,7 +563,6 @@ public class NativeImageLoader extends BaseImageLoader {
public INDArray asMatrix(Mat image) throws IOException { public INDArray asMatrix(Mat image) throws IOException {
INDArray ret = transformImage(image, null); INDArray ret = transformImage(image, null);
return ret.reshape(ArrayUtil.combine(new long[] {1}, ret.shape())); return ret.reshape(ArrayUtil.combine(new long[] {1}, ret.shape()));
} }
@ -678,6 +616,7 @@ public class NativeImageLoader extends BaseImageLoader {
throw new IOException("Cannot convert from " + image.channels() + " to " + channels + " channels."); throw new IOException("Cannot convert from " + image.channels() + " to " + channels + " channels.");
} }
image2 = new Mat(); image2 = new Mat();
if(image.rows() == 0 && image.cols() == 0) throw new RuntimeException("Cannot convert image with source dimensions 0x0");
cvtColor(image, image2, code); cvtColor(image, image2, code);
image = image2; image = image2;
} }
@ -895,9 +834,12 @@ public class NativeImageLoader extends BaseImageLoader {
* @return INDArray * @return INDArray
* @throws IOException * @throws IOException
*/ */
private INDArray asMatrix(BytePointer bytes, long length) throws IOException { private INDArray asMatrix(@NonNull BytePointer bytes, long length) throws IOException {
PIXA pixa; //This is an array of PIX (due to multipage)
pixa = pixaReadMemMultipageTiff(bytes, length); PIXA pixa = pixaReadMemMultipageTiff(bytes, length);
if(pixa == null) throw new RuntimeException("Error reading multipage PIX");
INDArray data; INDArray data;
INDArray currentD; INDArray currentD;
INDArrayIndex[] index = null; INDArrayIndex[] index = null;

View File

@ -122,7 +122,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
} }
protected boolean containsFormat(String format) { protected boolean containsFormat(String format) {
for (String format2 : imageLoader.getAllowedFormats()) for (String format2 : imageLoader.getALLOWED_FORMATS())
if (format.endsWith("." + format2)) if (format.endsWith("." + format2))
return true; return true;
return false; return false;
@ -172,6 +172,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
//remove the root directory //remove the root directory
FileSplit split1 = (FileSplit) split; FileSplit split1 = (FileSplit) split;
labels.remove(split1.getRootDir()); labels.remove(split1.getRootDir());
split1.close();
} }
//To ensure consistent order for label assignment (irrespective of file iteration order), we want to sort the list of labels //To ensure consistent order for label assignment (irrespective of file iteration order), we want to sort the list of labels
@ -405,7 +406,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
@Override @Override
public void close() throws IOException { public void close() throws IOException {
//No op this.inputSplit.close();
} }
@Override @Override

View File

@ -272,6 +272,30 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
} }
} }
public List<Writable> record(URI uri, File f) throws IOException {
invokeListeners(uri);
if (imageLoader == null) {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
}
Image image = this.imageLoader.asImageMatrix(f);
if(!nchw)
image.setImage(image.getImage().permute(0,2,3,1));
Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
List<Writable> ret = RecordConverter.toRecord(image.getImage());
if (appendLabel) {
List<ImageObject> imageObjectsForPath = labelProvider.getImageObjectsForPath(uri.getPath());
int nClasses = labels.size();
INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW);
label(image, imageObjectsForPath, outLabel, 0);
if(!nchw)
outLabel = outLabel.permute(0,2,3,1); //NCHW to NHWC
ret.add(new NDArrayWritable(outLabel));
}
return ret;
}
@Override @Override
public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException { public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
invokeListeners(uri); invokeListeners(uri);

View File

@ -21,14 +21,24 @@
package org.datavec.image; package org.datavec.image;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.CleanupMode;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.nio.file.FileVisitResult;
import java.nio.file.FileVisitor;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -37,32 +47,28 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class LabelGeneratorTest { public class LabelGeneratorTest {
@TempDir
public File testDir;
@Test @ParameterizedTest
public void testParentPathLabelGenerator() throws Exception { @ValueSource(strings = {"m", "m.", "something"})
public void testParentPathLabelGenerator(String dirPrefix, @TempDir Path testDir) throws Exception {
//https://github.com/deeplearning4j/DataVec/issues/273 //https://github.com/deeplearning4j/DataVec/issues/273
File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile(); File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
for(String dirPrefix : new String[]{"m.", "m"}) {
File f = testDir;
int numDirs = 3; int numDirs = 3;
int filesPerDir = 4; int filesPerDir = 4;
for (int i = 0; i < numDirs; i++) { for (int i = 0; i < numDirs; i++) {
File currentLabelDir = new File(f, dirPrefix + i); File currentLabelDir = new File(testDir.toFile(), dirPrefix + i);
currentLabelDir.mkdirs();
for (int j = 0; j < filesPerDir; j++) { for (int j = 0; j < filesPerDir; j++) {
File f3 = new File(currentLabelDir, "myImg_" + j + ".jpg"); File f3 = new File(currentLabelDir, "myImg_" + j + ".jpg");
FileUtils.copyFile(orig, f3); FileUtils.copyFile(orig, f3); //will create directories as needed
assertTrue(f3.exists()); assertTrue(f3.exists());
} }
} }
ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator()); ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
rr.initialize(new FileSplit(f)); rr.initialize(new FileSplit(testDir));
List<String> labelsAct = rr.getLabels(); List<String> labelsAct = rr.getLabels();
List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2"); List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2");
@ -74,7 +80,9 @@ public class LabelGeneratorTest {
rr.next(); rr.next();
actCount++; actCount++;
} }
rr.close();
assertEquals(expCount, actCount); assertEquals(expCount, actCount);
}
} }
} }

View File

@ -28,29 +28,36 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.leptonica.PIX;
import org.bytedeco.leptonica.PIXCMAP;
import org.bytedeco.opencv.opencv_core.Mat;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.*; import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.Random; import java.util.Random;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.bytedeco.leptonica.*;
import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.leptonica.global.lept.*; import static org.bytedeco.leptonica.global.lept.*;
import static org.bytedeco.opencv.global.opencv_core.*; import static org.bytedeco.opencv.global.opencv_core.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
/** /**
* *
@ -61,73 +68,74 @@ public class TestNativeImageLoader {
static final long seed = 10; static final long seed = 10;
static final Random rng = new Random(seed); static final Random rng = new Random(seed);
@TempDir @ParameterizedTest(name = "#{index} - Run test with arguments WxHxD {0}, {1}, {2}")
public File testDir; @MethodSource("generateDimensions")
public void testConvertPIX(int width, int height, int depth, int matType) {
PIX pix = pixCreate(width, height, depth);
Mat mat = NativeImageLoader.convert(pix);
assertEquals(width, mat.cols());
assertEquals(height, mat.rows());
if(matType==CV_8UC4) matType= CV_8UC1; //this would be for 8 bit 256 gradients
assertEquals(matType, mat.type());
}
/**
* Run PIX creation and conversation test with ColorMap
* @param width
* @param height
* @param depth
* @param matType
*/
@ParameterizedTest(name = "#{index} - Run test with arguments WxHxD {0}, {1}, {2}")
@MethodSource("generateDimensions")
public void testConvertPIXCMAP(int width, int height, int depth, int matType) {
// a GIF file, for example
PIX pix = pixCreate(width, height, depth);
PIXCMAP cmap = pixcmapCreateLinear(depth, 256);
pixSetColormap(pix, cmap);
Mat mat = NativeImageLoader.convert(pix);
assertEquals(width, mat.cols());
assertEquals(height, mat.rows());
if( matType == CV_8UC1 && depth >= 8 ) matType = CV_8UC4; //change the argument, as this is depth 8 with 256 shades
assertEquals(matType, mat.type());
}
// Static stream of arguments
static Stream<Arguments> generateDimensions() {
return Stream.of(
Arguments.arguments(20, 20 ,1, CV_8UC1),
Arguments.arguments(1, 1, 1, CV_8UC1),
Arguments.arguments(1014, 1080, 1, CV_8UC1),
Arguments.arguments(20, 20 ,2, CV_8UC1),
Arguments.arguments(1, 1, 2, CV_8UC1),
Arguments.arguments(1014, 1080, 2, CV_8UC1),
Arguments.arguments(20, 20 ,4, CV_8UC1),
Arguments.arguments(1, 1, 4, CV_8UC1),
Arguments.arguments(1014, 1080, 4, CV_8UC1),
Arguments.arguments(20, 20 ,8, CV_8UC1),
Arguments.arguments(1, 1, 8, CV_8UC1),
Arguments.arguments(1014, 1080, 16, CV_16UC(1)),
Arguments.arguments(1014, 1080, 24, CV_8UC(3)),
Arguments.arguments(1014, 1080, 32, CV_32FC1),
Arguments.arguments(2048, 4096, 32, CV_32FC1),
Arguments.arguments(300, 300, 8, CV_8UC4)
);
}
@Test @Test
public void testConvertPix() throws Exception { public void testConvertPix() throws Exception {
PIX pix;
Mat mat;
pix = pixCreate(11, 22, 1);
mat = NativeImageLoader.convert(pix);
assertEquals(11, mat.cols());
assertEquals(22, mat.rows());
assertEquals(CV_8UC1, mat.type());
pix = pixCreate(33, 44, 2);
mat = NativeImageLoader.convert(pix);
assertEquals(33, mat.cols());
assertEquals(44, mat.rows());
assertEquals(CV_8UC1, mat.type());
pix = pixCreate(55, 66, 4);
mat = NativeImageLoader.convert(pix);
assertEquals(55, mat.cols());
assertEquals(66, mat.rows());
assertEquals(CV_8UC1, mat.type());
pix = pixCreate(77, 88, 8);
mat = NativeImageLoader.convert(pix);
assertEquals(77, mat.cols());
assertEquals(88, mat.rows());
assertEquals(CV_8UC1, mat.type());
pix = pixCreate(99, 111, 16);
mat = NativeImageLoader.convert(pix);
assertEquals(99, mat.cols());
assertEquals(111, mat.rows());
assertEquals(CV_16UC(1), mat.type());
pix = pixCreate(222, 333, 24);
mat = NativeImageLoader.convert(pix);
assertEquals(222, mat.cols());
assertEquals(333, mat.rows());
assertEquals(CV_8UC(3), mat.type());
pix = pixCreate(444, 555, 32);
mat = NativeImageLoader.convert(pix);
assertEquals(444, mat.cols());
assertEquals(555, mat.rows());
assertEquals(CV_32FC1, mat.type());
// a GIF file, for example
pix = pixCreate(32, 32, 8);
PIXCMAP cmap = pixcmapCreateLinear(8, 256);
pixSetColormap(pix, cmap);
mat = NativeImageLoader.convert(pix);
assertEquals(32, mat.cols());
assertEquals(32, mat.rows());
assertEquals(CV_8UC4, mat.type());
int w4 = 100, h4 = 238, ch4 = 1, pages = 1, depth = 1; int w4 = 100, h4 = 238, ch4 = 1, pages = 1, depth = 1;
String path2MitosisFile = "datavec-data-image/testimages2/mitosis.tif"; String path2MitosisFile = "datavec-data-image/testimages2/mitosis.tif";
NativeImageLoader loader5 = new NativeImageLoader(h4, w4, ch4, NativeImageLoader.MultiPageMode.FIRST); NativeImageLoader loader5 = new NativeImageLoader(h4, w4, ch4, NativeImageLoader.MultiPageMode.FIRST);
INDArray array6 = null; INDArray array6 = null;
try { try {
array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); File f = new ClassPathResource(path2MitosisFile).getFile();
assertTrue(!f.isDirectory() && f.canRead());
array6 = loader5.asMatrix( f );
} catch (IOException e) { } catch (IOException e) {
log.error("",e); System.out.println(e.getMessage());
fail(); fail();
} }
assertEquals(5, array6.rank()); assertEquals(5, array6.rank());
@ -158,7 +166,7 @@ public class TestNativeImageLoader {
NativeImageLoader loader7 = new NativeImageLoader(h4, w4, ch6, NativeImageLoader.MultiPageMode.MINIBATCH); NativeImageLoader loader7 = new NativeImageLoader(h4, w4, ch6, NativeImageLoader.MultiPageMode.MINIBATCH);
INDArray array8 = null; INDArray array8 = null;
try { try {
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile());
} catch (IOException e) { } catch (IOException e) {
log.error("",e); log.error("",e);
} }
@ -174,7 +182,7 @@ public class TestNativeImageLoader {
NativeImageLoader loader8 = new NativeImageLoader(h5, w5, ch6, NativeImageLoader.MultiPageMode.MINIBATCH); NativeImageLoader loader8 = new NativeImageLoader(h5, w5, ch6, NativeImageLoader.MultiPageMode.MINIBATCH);
INDArray array9 = null; INDArray array9 = null;
try { try {
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath()); array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile());
} catch (IOException e) { } catch (IOException e) {
log.error("",e); log.error("",e);
fail(); fail();
@ -481,7 +489,7 @@ public class TestNativeImageLoader {
int w1 = 33, h1 = 77, ch1 = 1; int w1 = 33, h1 = 77, ch1 = 1;
NativeImageLoader loader1 = new NativeImageLoader(h1, w1, ch1); NativeImageLoader loader1 = new NativeImageLoader(h1, w1, ch1);
INDArray array1 = loader1.asMatrix(f0); INDArray array1 = loader1.asMatrix(new File(f0));
assertEquals(4, array1.rank()); assertEquals(4, array1.rank());
assertEquals(1, array1.size(0)); assertEquals(1, array1.size(0));
assertEquals(1, array1.size(1)); assertEquals(1, array1.size(1));
@ -565,40 +573,43 @@ public class TestNativeImageLoader {
@Test @Test
public void testNativeImageLoaderEmptyStreams() throws Exception { public void testNativeImageLoaderEmptyStreams(@TempDir Path tempDir) throws Exception {
File dir = testDir; File dir = tempDir.toFile();
File f = new File(dir, "myFile.jpg"); File f = new File(dir, "myFile.jpg");
f.createNewFile(); f.createNewFile();
NativeImageLoader nil = new NativeImageLoader(32, 32, 3); NativeImageLoader nil = new NativeImageLoader(32, 32, 3);
try(InputStream is = new FileInputStream(f)){ try {
nil.asMatrix(is); nil.asMatrix(f);
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg.contains("decode image"), msg); assertTrue(msg.contains("decode image"), msg);
} }
try(InputStream is = new FileInputStream(f)){ try {
nil.asImageMatrix(is); nil.asImageMatrix(f);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg.contains("decode image"), msg);
} catch (RuntimeException e) {
String msg = e.getMessage();
assertTrue(msg.contains("Cannot convert image with source dimensions 0x0"));
}
try{
nil.asRowVector(f);
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg.contains("decode image"), msg); assertTrue(msg.contains("decode image"), msg);
} }
try(InputStream is = new FileInputStream(f)){ try{
nil.asRowVector(is);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg.contains("decode image"), msg);
}
try(InputStream is = new FileInputStream(f)){
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32); INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
nil.asMatrixView(is, arr); nil.asMatrixView(f, arr);
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();

View File

@ -20,6 +20,8 @@
package org.nd4j.common.io; package org.nd4j.common.io;
import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
@ -38,43 +40,32 @@ import java.util.zip.ZipFile;
public class ClassPathResource extends AbstractFileResolvingResource { public class ClassPathResource extends AbstractFileResolvingResource {
private final String path; @Getter private final String path;
private ClassLoader classLoader; @Getter private final ClassLoader classLoader;
private Class<?> clazz; private Class<?> clazz;
public ClassPathResource(String path) { public ClassPathResource(@NonNull String path) {
this(path, (ClassLoader) null); this(path, ND4JClassLoading.getNd4jClassloader());
} }
public ClassPathResource(String path, ClassLoader classLoader) { public ClassPathResource(@NonNull String path, @NonNull ClassLoader classLoader) {
Assert.notNull(path, "Path must not be null");
String pathToUse = StringUtils.cleanPath(path); String pathToUse = StringUtils.cleanPath(path);
if (pathToUse.startsWith("/")) { if (pathToUse.startsWith("/")) {
pathToUse = pathToUse.substring(1); pathToUse = pathToUse.substring(1);
} }
this.path = pathToUse; this.path = pathToUse;
this.classLoader = classLoader != null ? classLoader : ND4JClassLoading.getNd4jClassloader();
}
public ClassPathResource(String path, Class<?> clazz) {
Assert.notNull(path, "Path must not be null");
this.path = StringUtils.cleanPath(path);
this.clazz = clazz;
}
protected ClassPathResource(String path, ClassLoader classLoader, Class<?> clazz) {
this.path = StringUtils.cleanPath(path);
this.classLoader = classLoader; this.classLoader = classLoader;
}
public ClassPathResource(@NonNull String path, @NonNull Class<?> clazz) {
this(path, clazz.getClassLoader());
this.clazz = clazz; this.clazz = clazz;
} }
public final String getPath() { protected ClassPathResource(@NonNull String path, @NonNull ClassLoader classLoader, @NonNull Class<?> clazz) {
return this.path; this(path, classLoader);
} this.clazz = clazz;
public final ClassLoader getClassLoader() {
return this.classLoader != null ? this.classLoader : this.clazz.getClassLoader();
} }
/** /**
@ -133,14 +124,12 @@ public class ClassPathResource extends AbstractFileResolvingResource {
} else { } else {
tmpFile = Files.createTempFile(FilenameUtils.getName(path), "tmp").toFile(); tmpFile = Files.createTempFile(FilenameUtils.getName(path), "tmp").toFile();
} }
tmpFile.deleteOnExit();
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(tmpFile)); BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(tmpFile));
IOUtils.copy(is, bos); IOUtils.copy(is, bos);
bos.flush(); bos.flush();
bos.close(); bos.close();
is.close();
return tmpFile; return tmpFile;
} }

View File

@ -213,6 +213,7 @@ chipList.each { thisChip ->
if (project.hasProperty("skip-native") && project.getProperty("skip-native").equals("true")) { if (project.hasProperty("skip-native") && project.getProperty("skip-native").equals("true")) {
enabled = false enabled = false
} }
dependsOn "processResources"
properties = getBuildPlatform( thisChip, it ) properties = getBuildPlatform( thisChip, it )