Check for empty streams for NativeImageLoader + test (#121)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-12-10 12:11:05 +11:00 committed by GitHub
parent a5f5ac72b1
commit 4920f22fff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 0 deletions

View File

@ -24,6 +24,7 @@ import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.datavec.image.transform.ImageTransform; import org.datavec.image.transform.ImageTransform;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -284,6 +285,9 @@ public class NativeImageLoader extends BaseImageLoader {
private Mat streamToMat(InputStream is) throws IOException { private Mat streamToMat(InputStream is) throws IOException {
if(buffer == null){ if(buffer == null){
buffer = IOUtils.toByteArray(is); buffer = IOUtils.toByteArray(is);
if(buffer.length <= 0){
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
}
bufferMat = new Mat(buffer); bufferMat = new Mat(buffer);
return bufferMat; return bufferMat;
} else { } else {
@ -292,6 +296,10 @@ public class NativeImageLoader extends BaseImageLoader {
//(a) if numRead < buffer.length - got everything //(a) if numRead < buffer.length - got everything
//(b) if numRead >= buffer.length: we MIGHT have got everything (exact right size buffer) OR we need more data //(b) if numRead >= buffer.length: we MIGHT have got everything (exact right size buffer) OR we need more data
if(numReadTotal <= 0){
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
}
if(numReadTotal < buffer.length){ if(numReadTotal < buffer.length){
bufferMat.data().put(buffer, 0, numReadTotal); bufferMat.data().put(buffer, 0, numReadTotal);
bufferMat.cols(numReadTotal); bufferMat.cols(numReadTotal);

View File

@ -24,7 +24,9 @@ import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -55,6 +57,9 @@ public class TestNativeImageLoader {
static final long seed = 10; static final long seed = 10;
static final Random rng = new Random(seed); static final Random rng = new Random(seed);
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testConvertPix() throws Exception { public void testConvertPix() throws Exception {
PIX pix; PIX pix;
@ -554,4 +559,43 @@ public class TestNativeImageLoader {
assertEquals(img1LargeBuffer, img1ExactBuffer); assertEquals(img1LargeBuffer, img1ExactBuffer);
} }
@Test
public void testNativeImageLoaderEmptyStreams() throws Exception {
File dir = testDir.newFolder();
File f = new File(dir, "myFile.jpg");
f.createNewFile();
NativeImageLoader nil = new NativeImageLoader(32, 32, 3);
try(InputStream is = new FileInputStream(f)){
nil.asMatrix(is);
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
}
try(InputStream is = new FileInputStream(f)){
nil.asImageMatrix(is);
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
}
try(InputStream is = new FileInputStream(f)){
nil.asRowVector(is);
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
}
try(InputStream is = new FileInputStream(f)){
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
nil.asMatrixView(is, arr);
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
}
}
} }