datavec-data-image test fixes

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2022-10-10 17:01:23 +02:00
parent 6cb5d30284
commit c46e6e4c68
24 changed files with 332 additions and 264 deletions

View File

@ -72,12 +72,16 @@ allprojects { Project proj ->
testAnnotationProcessor platform(project(":cavis-common-platform"))
testImplementation platform(project(":cavis-common-platform"))
compileOnly 'org.projectlombok:lombok'
annotationProcessor 'org.projectlombok:lombok'
testCompileOnly 'org.projectlombok:lombok'
testAnnotationProcessor 'org.projectlombok:lombok'
testImplementation 'org.junit.jupiter:junit-jupiter-engine'
testImplementation 'org.junit.jupiter:junit-jupiter-api'
compileOnly 'org.projectlombok:lombok'
annotationProcessor 'org.projectlombok:lombok'
testCompileOnly 'org.projectlombok:lombok'
testAnnotationProcessor 'org.projectlombok:lombok'
testImplementation 'org.junit.jupiter:junit-jupiter-engine'
testImplementation 'org.junit.jupiter:junit-jupiter-api'
testImplementation 'org.junit.jupiter:junit-jupiter-params'
implementation "org.slf4j:slf4j-api"
implementation "org.slf4j:slf4j-simple"
}
test {

View File

@ -28,10 +28,7 @@ import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.Serializable;
import java.io.*;
import java.net.URI;
import java.util.Collection;
import java.util.List;

View File

@ -79,4 +79,13 @@ public class CollectionInputSplit extends BaseInputSplit {
return true;
}
}
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
}

View File

@ -31,6 +31,7 @@ import org.nd4j.common.util.MathUtils;
import java.io.*;
import java.net.URI;
import java.nio.file.Path;
import java.util.*;
public class FileSplit extends BaseInputSplit {
@ -59,6 +60,10 @@ public class FileSplit extends BaseInputSplit {
this(rootDir, null, true, null, true);
}
public FileSplit(Path rootDir) {
this(rootDir.toFile(), null, true, null, true);
}
public FileSplit(File rootDir, Random rng) {
this(rootDir, null, true, rng, true);
}
@ -214,6 +219,14 @@ public class FileSplit extends BaseInputSplit {
return true;
}
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
public File getRootDir() {
return rootDir;

View File

@ -133,4 +133,9 @@ public interface InputSplit {
* may throw an exception
*/
boolean resetSupported();
/**
* Close input/ output streams if any
*/
void close();
}

View File

@ -21,6 +21,7 @@
package org.datavec.api.split;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
@ -149,6 +150,18 @@ public class InputStreamInputSplit implements InputSplit {
return location != null && location.length > 0;
}
/**
* Close input/ output streams if any
*/
@Override
public void close() {
try {
is.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public InputStream getIs() {
return is;

View File

@ -20,6 +20,7 @@
package org.datavec.api.split;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
@ -124,4 +125,12 @@ public class ListStringSplit implements InputSplit {
public List<List<String>> getData() {
return data;
}
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
}

View File

@ -153,6 +153,14 @@ public class NumberedFileInputSplit implements InputSplit {
return true;
}
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
private class NumberedFileIterator implements Iterator<String> {
@ -179,5 +187,7 @@ public class NumberedFileInputSplit implements InputSplit {
public void remove() {
throw new UnsupportedOperationException();
}
}
}

View File

@ -23,6 +23,7 @@ package org.datavec.api.split;
import lombok.Getter;
import lombok.Setter;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
@ -115,5 +116,17 @@ public class OutputStreamInputSplit implements InputSplit {
return false;
}
/**
* Close input/ output streams if any
*/
@Override
public void close() {
try {
outputStream.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -143,4 +143,12 @@ public class StreamInputSplit implements InputSplit {
public boolean resetSupported() {
return true;
}
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
}

View File

@ -107,6 +107,13 @@ public class StringSplit implements InputSplit {
return true;
}
/**
* Close input/ output streams if any
*/
@Override
public void close() {
}
public String getData() {

View File

@ -111,6 +111,15 @@ public class TransformSplit extends BaseInputSplit {
return true;
}
/**
* Close input/ output streams if any
*/
@Override
public void close() {
sourceSplit.close();
}
public interface URITransform {
URI apply(URI uri) throws URISyntaxException;
}

View File

@ -99,8 +99,6 @@ public class ExcelRecordWriter extends FileRecordWriter {
partitioner.init(inputSplit);
out = new DataOutputStream(partitioner.currentOutputStream());
initPoi();
}
private void initPoi() {

View File

@ -60,9 +60,6 @@ public class AndroidNativeImageLoader extends NativeImageLoader {
}
public INDArray asMatrix(Bitmap image) throws IOException {
if (converter == null) {
converter = new OpenCVFrameConverter.ToMat();
}
return asMatrix(converter.convert(converter2.convert(image)));
}

View File

@ -20,6 +20,7 @@
package org.datavec.image.loader;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.image.data.Image;
@ -43,7 +44,10 @@ public abstract class BaseImageLoader implements Serializable {
}
public static final File BASE_DIR = new File(System.getProperty("user.home"));
@Getter
public static final String[] ALLOWED_FORMATS = {"tif", "jpg", "png", "jpeg", "bmp", "JPEG", "JPG", "TIF", "PNG"};
protected Random rng = new Random(System.currentTimeMillis());
protected long height = -1;
@ -53,16 +57,17 @@ public abstract class BaseImageLoader implements Serializable {
protected ImageTransform imageTransform = null;
protected MultiPageMode multiPageMode = null;
public String[] getAllowedFormats() {
return ALLOWED_FORMATS;
}
public abstract INDArray asRowVector(File f) throws IOException;
public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
/** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */
public abstract INDArray asMatrix(File f) throws IOException;
/** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format.
* Essentially calls asMatrix(File f, true)
*
**/
public INDArray asMatrix(File f) throws IOException {
return asMatrix( f, true);
}
/**
* Load an image from a file to an INDArray
@ -73,7 +78,15 @@ public abstract class BaseImageLoader implements Serializable {
*/
public abstract INDArray asMatrix(File f, boolean nchw) throws IOException;
public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
/**
* Load an image file from an input stream to an INDArray. Essentially calls asMatrix(inputStream, true)
* {@link #asMatrix(InputStream, boolean)} asMatrix
* @param inputStream Input stream to load the image from
* @return Image file stream as as INDArray NCHW/channels_first [1, channels, height width] format
*/
public INDArray asMatrix(InputStream inputStream) throws IOException {
return asMatrix(inputStream, true);
}
/**
* Load an image file from an input stream to an INDArray
* @param inputStream Input stream to load the image from

View File

@ -257,16 +257,6 @@ public class ImageLoader extends BaseImageLoader {
}
}
/**
* Convert an input stream to a matrix
*
* @param inputStream the input stream to convert
* @return the input stream to convert
*/
public INDArray asMatrix(InputStream inputStream) throws IOException {
return asMatrix(inputStream, true);
}
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
INDArray ret;

View File

@ -85,9 +85,6 @@ public class Java2DNativeImageLoader extends NativeImageLoader {
* @throws IOException
*/
public INDArray asMatrix(BufferedImage image, boolean flipChannels) throws IOException {
if (converter == null) {
converter = new OpenCVFrameConverter.ToMat();
}
return asMatrix(converter.convert(converter2.getFrame(image, 1.0, flipChannels)));
}

View File

@ -20,6 +20,8 @@
package org.datavec.image.loader;
import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
@ -38,6 +40,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.util.ArrayUtil;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.bytedeco.leptonica.*;
@ -49,8 +52,9 @@ import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
/**
* Uses JavaCV to load images. Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp
*
* Uses JavaCV (that also wraps OpenCV) to load images.
* Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp
* JavaCV supports a wider range of image formats compared to the {@link ImageLoader} variant.
* @author saudet
*/
public class NativeImageLoader extends BaseImageLoader {
@ -58,12 +62,14 @@ public class NativeImageLoader extends BaseImageLoader {
private byte[] buffer = null;
private Mat bufferMat = null;
@Getter
public static final String[] ALLOWED_FORMATS = {"bmp", "gif", "jpg", "jpeg", "jp2", "pbm", "pgm", "ppm", "pnm",
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
"PNG", "TIF", "TIFF", "EXR", "WEBP"};
protected OpenCVFrameConverter.ToMat converter;
protected final OpenCVFrameConverter.ToMat converter;
//Todo: Should be final, but TestNativeImageLoader uses this to simulate for Android
boolean direct = !Loader.getPlatform().startsWith("android");
/**
@ -144,17 +150,9 @@ public class NativeImageLoader extends BaseImageLoader {
this.imageTransform = other.imageTransform;
}
@Override
public String[] getAllowedFormats() {
return ALLOWED_FORMATS;
}
public INDArray asRowVector(String filename) throws IOException {
return asRowVector(new File(filename));
}
/**
* Convert a file to a row vector
* Convert a file to a row vector by loading it into an {@link INDArray} and then
* calling flattening {@link INDArray#ravel()}.
*
* @param f the image to convert
* @return the flattened image
@ -164,7 +162,14 @@ public class NativeImageLoader extends BaseImageLoader {
public INDArray asRowVector(File f) throws IOException {
return asMatrix(f).ravel();
}
/**
* Convert an input stream containing an image to a row vector by loading it into an {@link INDArray} and then
* calling flattening {@link INDArray#ravel()}.
*
* @param is the image input stream to convert
* @return the flattened image
* @throws IOException
*/
@Override
public INDArray asRowVector(InputStream is) throws IOException {
return asMatrix(is).ravel();
@ -192,7 +197,15 @@ public class NativeImageLoader extends BaseImageLoader {
return arr.reshape('c', 1, arr.length());
}
static Mat convert(PIX pix) {
/**
* Helper method to convert a {@see http://leptonica.org Leptonica PIX} into an OpenCV Matrix.
* Leptonica is a pedagogically-oriented open source library containing software that is
* broadly useful for image processing and image analysis applications.
* @param pix the leptonica image format.
* @return OpenCV Matrix
*/
static Mat convert(@NonNull PIX pix) {
PIX tempPix = null;
int dtype = -1;
int height = pix.h();
@ -245,55 +258,23 @@ public class NativeImageLoader extends BaseImageLoader {
return mat2;
}
public INDArray asMatrix(String filename) throws IOException {
return asMatrix(new File(filename));
}
@Override
public INDArray asMatrix(File f) throws IOException {
public INDArray asMatrix(@NonNull File f) throws IOException {
return asMatrix(f, true);
}
@Override
public INDArray asMatrix(File f, boolean nchw) throws IOException {
Mat mat = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR );
INDArray a;
if (this.multiPageMode != null) {
a = asMatrix(mat.data(), mat.cols());
}else{
// Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (mat == null || mat.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols());
if (pix == null) {
throw new IOException("Could not decode image from input stream");
}
mat = convert(pix);
pixDestroy(pix);
}
a = asMatrix(mat);
mat.deallocate();
}
if(nchw) {
return a;
} else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC
}
public INDArray asMatrix(@NonNull File f, boolean nchw) throws IOException {
return asMatrix(new FileInputStream(f), nchw);
}
@Override
public INDArray asMatrix(InputStream is) throws IOException {
return asMatrix(is, true);
}
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
throw new RuntimeException("not implemented");
/*
public INDArray asMatrix(@NonNull InputStream inputStream, boolean nchw) throws IOException {
Mat mat = streamToMat(inputStream);
INDArray a;
if (this.multiPageMode != null) {
a = asMatrix(mat.data(), mat.cols());
}else{
a = asMatrix(mat.data(), mat.arraySize());
} else {
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols());
@ -311,8 +292,6 @@ public class NativeImageLoader extends BaseImageLoader {
} else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC
}
*/
}
/**
@ -321,53 +300,13 @@ public class NativeImageLoader extends BaseImageLoader {
* @return Mat with the buffer data as a row vector
* @throws IOException
*/
private Mat streamToMat(InputStream is) throws IOException {
if(buffer == null){
buffer = IOUtils.toByteArray(is);
if(buffer.length == 0){
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
}
bufferMat = new Mat(buffer);
return bufferMat;
} else {
int numReadTotal = is.read(buffer);
//Need to know if all data has been read.
//(a) if numRead < buffer.length - got everything
//(b) if numRead >= buffer.length: we MIGHT have got everything (exact right size buffer) OR we need more data
if(numReadTotal <= 0){
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
}
if(numReadTotal < buffer.length){
bufferMat.data().put(buffer, 0, numReadTotal);
bufferMat.cols(numReadTotal);
return bufferMat;
}
//Buffer is full; reallocate and keep reading
int numReadCurrent = numReadTotal;
while(numReadCurrent != -1){
byte[] oldBuffer = buffer;
if(oldBuffer.length == Integer.MAX_VALUE){
throw new IllegalStateException("Cannot read more than Integer.MAX_VALUE bytes");
}
//Double buffer, but allocate at least 1MB more
long increase = Math.max(buffer.length, MIN_BUFFER_STEP_SIZE);
int newBufferLength = (int)Math.min(Integer.MAX_VALUE, buffer.length + increase);
buffer = new byte[newBufferLength];
System.arraycopy(oldBuffer, 0, buffer, 0, oldBuffer.length);
numReadCurrent = is.read(buffer, oldBuffer.length, buffer.length - oldBuffer.length);
if(numReadCurrent > 0){
numReadTotal += numReadCurrent;
}
}
bufferMat = new Mat(buffer);
return bufferMat;
private Mat streamToMat(@NonNull InputStream is) throws IOException {
byte[] bytearray = IOUtils.toByteArray(is); //Todo: can be very large
if(bytearray == null || bytearray.length <= 0 ) {
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
}
Mat outputMat = new Mat(bytearray);
return outputMat;
}
public Image asImageMatrix(String filename) throws IOException {
@ -624,7 +563,6 @@ public class NativeImageLoader extends BaseImageLoader {
public INDArray asMatrix(Mat image) throws IOException {
INDArray ret = transformImage(image, null);
return ret.reshape(ArrayUtil.combine(new long[] {1}, ret.shape()));
}
@ -678,6 +616,7 @@ public class NativeImageLoader extends BaseImageLoader {
throw new IOException("Cannot convert from " + image.channels() + " to " + channels + " channels.");
}
image2 = new Mat();
if(image.rows() == 0 && image.cols() == 0) throw new RuntimeException("Cannot convert image with source dimensions 0x0");
cvtColor(image, image2, code);
image = image2;
}
@ -895,9 +834,12 @@ public class NativeImageLoader extends BaseImageLoader {
* @return INDArray
* @throws IOException
*/
private INDArray asMatrix(BytePointer bytes, long length) throws IOException {
PIXA pixa;
pixa = pixaReadMemMultipageTiff(bytes, length);
private INDArray asMatrix(@NonNull BytePointer bytes, long length) throws IOException {
//This is an array of PIX (due to multipage)
PIXA pixa = pixaReadMemMultipageTiff(bytes, length);
if(pixa == null) throw new RuntimeException("Error reading multipage PIX");
INDArray data;
INDArray currentD;
INDArrayIndex[] index = null;

View File

@ -122,7 +122,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
}
protected boolean containsFormat(String format) {
for (String format2 : imageLoader.getAllowedFormats())
for (String format2 : imageLoader.getALLOWED_FORMATS())
if (format.endsWith("." + format2))
return true;
return false;
@ -172,6 +172,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
//remove the root directory
FileSplit split1 = (FileSplit) split;
labels.remove(split1.getRootDir());
split1.close();
}
//To ensure consistent order for label assignment (irrespective of file iteration order), we want to sort the list of labels
@ -405,7 +406,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
@Override
public void close() throws IOException {
//No op
this.inputSplit.close();
}
@Override

View File

@ -272,6 +272,30 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
}
}
public List<Writable> record(URI uri, File f) throws IOException {
invokeListeners(uri);
if (imageLoader == null) {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
}
Image image = this.imageLoader.asImageMatrix(f);
if(!nchw)
image.setImage(image.getImage().permute(0,2,3,1));
Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
List<Writable> ret = RecordConverter.toRecord(image.getImage());
if (appendLabel) {
List<ImageObject> imageObjectsForPath = labelProvider.getImageObjectsForPath(uri.getPath());
int nClasses = labels.size();
INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW);
label(image, imageObjectsForPath, outLabel, 0);
if(!nchw)
outLabel = outLabel.permute(0,2,3,1); //NCHW to NHWC
ret.add(new NDArrayWritable(outLabel));
}
return ret;
}
@Override
public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
invokeListeners(uri);

View File

@ -21,14 +21,24 @@
package org.datavec.image;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.CleanupMode;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.io.IOException;
import java.nio.file.FileVisitResult;
import java.nio.file.FileVisitor;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.Arrays;
import java.util.List;
@ -37,32 +47,28 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class LabelGeneratorTest {
@TempDir
public File testDir;
@Test
public void testParentPathLabelGenerator() throws Exception {
@ParameterizedTest
@ValueSource(strings = {"m", "m.", "something"})
public void testParentPathLabelGenerator(String dirPrefix, @TempDir Path testDir) throws Exception {
//https://github.com/deeplearning4j/DataVec/issues/273
File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
for(String dirPrefix : new String[]{"m.", "m"}) {
File f = testDir;
int numDirs = 3;
int filesPerDir = 4;
for (int i = 0; i < numDirs; i++) {
File currentLabelDir = new File(f, dirPrefix + i);
currentLabelDir.mkdirs();
File currentLabelDir = new File(testDir.toFile(), dirPrefix + i);
for (int j = 0; j < filesPerDir; j++) {
File f3 = new File(currentLabelDir, "myImg_" + j + ".jpg");
FileUtils.copyFile(orig, f3);
FileUtils.copyFile(orig, f3); //will create directories as needed
assertTrue(f3.exists());
}
}
ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
rr.initialize(new FileSplit(f));
rr.initialize(new FileSplit(testDir));
List<String> labelsAct = rr.getLabels();
List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2");
@ -74,7 +80,9 @@ public class LabelGeneratorTest {
rr.next();
actCount++;
}
rr.close();
assertEquals(expCount, actCount);
}
}
}

View File

@ -28,29 +28,36 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.leptonica.PIX;
import org.bytedeco.leptonica.PIXCMAP;
import org.bytedeco.opencv.opencv_core.Mat;
import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.awt.image.BufferedImage;
import java.io.*;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.Random;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.bytedeco.leptonica.*;
import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.leptonica.global.lept.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.*;
/**
*
@ -61,73 +68,74 @@ public class TestNativeImageLoader {
static final long seed = 10;
static final Random rng = new Random(seed);
@TempDir
public File testDir;
@ParameterizedTest(name = "#{index} - Run test with arguments WxHxD {0}, {1}, {2}")
@MethodSource("generateDimensions")
public void testConvertPIX(int width, int height, int depth, int matType) {
PIX pix = pixCreate(width, height, depth);
Mat mat = NativeImageLoader.convert(pix);
assertEquals(width, mat.cols());
assertEquals(height, mat.rows());
if(matType==CV_8UC4) matType= CV_8UC1; //this would be for 8 bit 256 gradients
assertEquals(matType, mat.type());
}
/**
* Run PIX creation and conversation test with ColorMap
* @param width
* @param height
* @param depth
* @param matType
*/
@ParameterizedTest(name = "#{index} - Run test with arguments WxHxD {0}, {1}, {2}")
@MethodSource("generateDimensions")
public void testConvertPIXCMAP(int width, int height, int depth, int matType) {
// a GIF file, for example
PIX pix = pixCreate(width, height, depth);
PIXCMAP cmap = pixcmapCreateLinear(depth, 256);
pixSetColormap(pix, cmap);
Mat mat = NativeImageLoader.convert(pix);
assertEquals(width, mat.cols());
assertEquals(height, mat.rows());
if( matType == CV_8UC1 && depth >= 8 ) matType = CV_8UC4; //change the argument, as this is depth 8 with 256 shades
assertEquals(matType, mat.type());
}
// Static stream of arguments
static Stream<Arguments> generateDimensions() {
return Stream.of(
Arguments.arguments(20, 20 ,1, CV_8UC1),
Arguments.arguments(1, 1, 1, CV_8UC1),
Arguments.arguments(1014, 1080, 1, CV_8UC1),
Arguments.arguments(20, 20 ,2, CV_8UC1),
Arguments.arguments(1, 1, 2, CV_8UC1),
Arguments.arguments(1014, 1080, 2, CV_8UC1),
Arguments.arguments(20, 20 ,4, CV_8UC1),
Arguments.arguments(1, 1, 4, CV_8UC1),
Arguments.arguments(1014, 1080, 4, CV_8UC1),
Arguments.arguments(20, 20 ,8, CV_8UC1),
Arguments.arguments(1, 1, 8, CV_8UC1),
Arguments.arguments(1014, 1080, 16, CV_16UC(1)),
Arguments.arguments(1014, 1080, 24, CV_8UC(3)),
Arguments.arguments(1014, 1080, 32, CV_32FC1),
Arguments.arguments(2048, 4096, 32, CV_32FC1),
Arguments.arguments(300, 300, 8, CV_8UC4)
);
}
@Test
public void testConvertPix() throws Exception {
PIX pix;
Mat mat;
pix = pixCreate(11, 22, 1);
mat = NativeImageLoader.convert(pix);
assertEquals(11, mat.cols());
assertEquals(22, mat.rows());
assertEquals(CV_8UC1, mat.type());
pix = pixCreate(33, 44, 2);
mat = NativeImageLoader.convert(pix);
assertEquals(33, mat.cols());
assertEquals(44, mat.rows());
assertEquals(CV_8UC1, mat.type());
pix = pixCreate(55, 66, 4);
mat = NativeImageLoader.convert(pix);
assertEquals(55, mat.cols());
assertEquals(66, mat.rows());
assertEquals(CV_8UC1, mat.type());
pix = pixCreate(77, 88, 8);
mat = NativeImageLoader.convert(pix);
assertEquals(77, mat.cols());
assertEquals(88, mat.rows());
assertEquals(CV_8UC1, mat.type());
pix = pixCreate(99, 111, 16);
mat = NativeImageLoader.convert(pix);
assertEquals(99, mat.cols());
assertEquals(111, mat.rows());
assertEquals(CV_16UC(1), mat.type());
pix = pixCreate(222, 333, 24);
mat = NativeImageLoader.convert(pix);
assertEquals(222, mat.cols());
assertEquals(333, mat.rows());
assertEquals(CV_8UC(3), mat.type());
pix = pixCreate(444, 555, 32);
mat = NativeImageLoader.convert(pix);
assertEquals(444, mat.cols());
assertEquals(555, mat.rows());
assertEquals(CV_32FC1, mat.type());
// a GIF file, for example
pix = pixCreate(32, 32, 8);
PIXCMAP cmap = pixcmapCreateLinear(8, 256);
pixSetColormap(pix, cmap);
mat = NativeImageLoader.convert(pix);
assertEquals(32, mat.cols());
assertEquals(32, mat.rows());
assertEquals(CV_8UC4, mat.type());
int w4 = 100, h4 = 238, ch4 = 1, pages = 1, depth = 1;
String path2MitosisFile = "datavec-data-image/testimages2/mitosis.tif";
NativeImageLoader loader5 = new NativeImageLoader(h4, w4, ch4, NativeImageLoader.MultiPageMode.FIRST);
INDArray array6 = null;
try {
array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
File f = new ClassPathResource(path2MitosisFile).getFile();
assertTrue(!f.isDirectory() && f.canRead());
array6 = loader5.asMatrix( f );
} catch (IOException e) {
log.error("",e);
System.out.println(e.getMessage());
fail();
}
assertEquals(5, array6.rank());
@ -158,7 +166,7 @@ public class TestNativeImageLoader {
NativeImageLoader loader7 = new NativeImageLoader(h4, w4, ch6, NativeImageLoader.MultiPageMode.MINIBATCH);
INDArray array8 = null;
try {
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile());
} catch (IOException e) {
log.error("",e);
}
@ -174,7 +182,7 @@ public class TestNativeImageLoader {
NativeImageLoader loader8 = new NativeImageLoader(h5, w5, ch6, NativeImageLoader.MultiPageMode.MINIBATCH);
INDArray array9 = null;
try {
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath());
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile());
} catch (IOException e) {
log.error("",e);
fail();
@ -481,7 +489,7 @@ public class TestNativeImageLoader {
int w1 = 33, h1 = 77, ch1 = 1;
NativeImageLoader loader1 = new NativeImageLoader(h1, w1, ch1);
INDArray array1 = loader1.asMatrix(f0);
INDArray array1 = loader1.asMatrix(new File(f0));
assertEquals(4, array1.rank());
assertEquals(1, array1.size(0));
assertEquals(1, array1.size(1));
@ -565,40 +573,43 @@ public class TestNativeImageLoader {
@Test
public void testNativeImageLoaderEmptyStreams() throws Exception {
File dir = testDir;
public void testNativeImageLoaderEmptyStreams(@TempDir Path tempDir) throws Exception {
File dir = tempDir.toFile();
File f = new File(dir, "myFile.jpg");
f.createNewFile();
NativeImageLoader nil = new NativeImageLoader(32, 32, 3);
try(InputStream is = new FileInputStream(f)){
nil.asMatrix(is);
try {
nil.asMatrix(f);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg.contains("decode image"), msg);
}
try(InputStream is = new FileInputStream(f)){
nil.asImageMatrix(is);
try {
nil.asImageMatrix(f);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg.contains("decode image"), msg);
} catch (RuntimeException e) {
String msg = e.getMessage();
assertTrue(msg.contains("Cannot convert image with source dimensions 0x0"));
}
try{
nil.asRowVector(f);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg.contains("decode image"), msg);
}
try(InputStream is = new FileInputStream(f)){
nil.asRowVector(is);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg.contains("decode image"), msg);
}
try(InputStream is = new FileInputStream(f)){
try{
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
nil.asMatrixView(is, arr);
nil.asMatrixView(f, arr);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();

View File

@ -20,6 +20,8 @@
package org.nd4j.common.io;
import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
@ -38,43 +40,32 @@ import java.util.zip.ZipFile;
public class ClassPathResource extends AbstractFileResolvingResource {
private final String path;
private ClassLoader classLoader;
@Getter private final String path;
@Getter private final ClassLoader classLoader;
private Class<?> clazz;
public ClassPathResource(String path) {
this(path, (ClassLoader) null);
public ClassPathResource(@NonNull String path) {
this(path, ND4JClassLoading.getNd4jClassloader());
}
public ClassPathResource(String path, ClassLoader classLoader) {
Assert.notNull(path, "Path must not be null");
public ClassPathResource(@NonNull String path, @NonNull ClassLoader classLoader) {
String pathToUse = StringUtils.cleanPath(path);
if (pathToUse.startsWith("/")) {
pathToUse = pathToUse.substring(1);
}
this.path = pathToUse;
this.classLoader = classLoader != null ? classLoader : ND4JClassLoading.getNd4jClassloader();
}
public ClassPathResource(String path, Class<?> clazz) {
Assert.notNull(path, "Path must not be null");
this.path = StringUtils.cleanPath(path);
this.clazz = clazz;
}
protected ClassPathResource(String path, ClassLoader classLoader, Class<?> clazz) {
this.path = StringUtils.cleanPath(path);
this.classLoader = classLoader;
}
public ClassPathResource(@NonNull String path, @NonNull Class<?> clazz) {
this(path, clazz.getClassLoader());
this.clazz = clazz;
}
public final String getPath() {
return this.path;
}
public final ClassLoader getClassLoader() {
return this.classLoader != null ? this.classLoader : this.clazz.getClassLoader();
protected ClassPathResource(@NonNull String path, @NonNull ClassLoader classLoader, @NonNull Class<?> clazz) {
this(path, classLoader);
this.clazz = clazz;
}
/**
@ -133,14 +124,12 @@ public class ClassPathResource extends AbstractFileResolvingResource {
} else {
tmpFile = Files.createTempFile(FilenameUtils.getName(path), "tmp").toFile();
}
tmpFile.deleteOnExit();
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(tmpFile));
IOUtils.copy(is, bos);
bos.flush();
bos.close();
is.close();
return tmpFile;
}

View File

@ -222,6 +222,7 @@ chipList.each { thisChip ->
if (project.hasProperty("skip-native") && project.getProperty("skip-native").equals("true")) {
enabled = false
}
dependsOn "processResources"
properties = getBuildPlatform( thisChip, it )