From 4920f22fffaf00aafa4373252ef1d6fe705f79f5 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 10 Dec 2019 12:11:05 +1100 Subject: [PATCH] Check for empty streams for NativeImageLoader + test (#121) Signed-off-by: AlexDBlack --- .../image/loader/NativeImageLoader.java | 8 ++++ .../image/loader/TestNativeImageLoader.java | 44 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java index 8f482846b..d2be87536 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java @@ -24,6 +24,7 @@ import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.Image; import org.datavec.image.data.ImageWritable; import org.datavec.image.transform.ImageTransform; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.api.ndarray.INDArray; @@ -284,6 +285,9 @@ public class NativeImageLoader extends BaseImageLoader { private Mat streamToMat(InputStream is) throws IOException { if(buffer == null){ buffer = IOUtils.toByteArray(is); + if(buffer.length <= 0){ + throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); + } bufferMat = new Mat(buffer); return bufferMat; } else { @@ -292,6 +296,10 @@ public class NativeImageLoader extends BaseImageLoader { //(a) if numRead < buffer.length - got everything //(b) if numRead >= buffer.length: we MIGHT have got everything (exact right size buffer) OR we need more data + if(numReadTotal <= 0){ + throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); + } + if(numReadTotal < buffer.length){ bufferMat.data().put(buffer, 0, numReadTotal); bufferMat.cols(numReadTotal); diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java index 544e46b77..5f634bab8 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java @@ -24,7 +24,9 @@ import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -55,6 +57,9 @@ public class TestNativeImageLoader { static final long seed = 10; static final Random rng = new Random(seed); + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + @Test public void testConvertPix() throws Exception { PIX pix; @@ -554,4 +559,43 @@ public class TestNativeImageLoader { assertEquals(img1LargeBuffer, img1ExactBuffer); } + + @Test + public void testNativeImageLoaderEmptyStreams() throws Exception { + File dir = testDir.newFolder(); + File f = new File(dir, "myFile.jpg"); + f.createNewFile(); + + NativeImageLoader nil = new NativeImageLoader(32, 32, 3); + + try(InputStream is = new FileInputStream(f)){ + nil.asMatrix(is); + } catch (IOException e){ + String msg = e.getMessage(); + assertTrue(msg, msg.contains("decode image")); + } + + try(InputStream is = new FileInputStream(f)){ + nil.asImageMatrix(is); + } catch (IOException e){ + String msg = e.getMessage(); + assertTrue(msg, msg.contains("decode image")); + } + + try(InputStream is = new FileInputStream(f)){ + nil.asRowVector(is); + } catch (IOException e){ + String msg = e.getMessage(); + assertTrue(msg, msg.contains("decode image")); + } + + try(InputStream is = new FileInputStream(f)){ + INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32); + nil.asMatrixView(is, arr); + } catch (IOException e){ + String msg = e.getMessage(); + assertTrue(msg, msg.contains("decode image")); + } + } + }