From ee3e059b12ac4994289f17319792f557982391c9 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Fri, 5 Jun 2020 11:49:02 +1000 Subject: [PATCH] DL4J/DataVec: Fix Yolo2OutputLayer and ObjectDetectionRecordReader support for NHWC data format (#483) * Fix Yolo2OutputLayer for NHWC data format Signed-off-by: Alex Black * ObjectDetectionRecordReader NHWC support Signed-off-by: Alex Black --- .../ObjectDetectionRecordReader.java | 57 +++- .../TestObjectDetectionRecordReader.java | 281 +++++++++--------- .../gradientcheck/YoloGradientCheckTests.java | 39 ++- .../layers/objdetect/Yolo2OutputLayer.java | 6 +- .../nn/layers/objdetect/Yolo2OutputLayer.java | 15 +- .../nn/layers/objdetect/YoloUtils.java | 20 +- 6 files changed, 260 insertions(+), 158 deletions(-) diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java index 1a53a05ac..38afd6adf 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java @@ -49,7 +49,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point; /** * An image record reader for object detection. *

- * Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w] + * Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w] (NCHW) or [minibatch, h, w, 4+C] (NHWC) * Where the image is quantized into h x w grid locations. *

* Note that this matches the format required for Deeplearning4j's Yolo2OutputLayer @@ -61,42 +61,67 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { private final int gridW; private final int gridH; private final ImageObjectLabelProvider labelProvider; + private final boolean nchw; protected Image currentImage; /** + * As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider)} but hardcoded + * to NCHW format + */ + public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) { + this(height, width, channels, gridH, gridW, true, labelProvider); + } + + /** + * Create an ObjectDetectionRecordReader with the given image size, grid size, label format and label provider * * @param height Height of the output images * @param width Width of the output images * @param channels Number of channels for the output images * @param gridH Grid/quantization size (along height dimension) - Y axis * @param gridW Grid/quantization size (along height dimension) - X axis + * @param nchw If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return + * NHWC format labels with array shape [minibatch, h, w, 4+C] * @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image */ - public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) { + public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw, ImageObjectLabelProvider labelProvider) { super(height, width, channels, null, null); this.gridW = gridW; this.gridH = gridH; + this.nchw = nchw; this.labelProvider = labelProvider; this.appendLabel = labelProvider != null; } /** - * When imageTransform != null, object is removed if new center is outside of transformed image bounds. 
- * - * @param height Height of the output images - * @param width Width of the output images - * @param channels Number of channels for the output images - * @param gridH Grid/quantization size (along height dimension) - Y axis - * @param gridW Grid/quantization size (along height dimension) - X axis - * @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image - * @param imageTransform ImageTransform - used to transform image and coordinates + * As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider, ImageTransform)} + * but hardcoded to NCHW format */ public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, - ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) { + ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) { + this(height, width, channels, gridH, gridW, true, labelProvider, imageTransform); + } + + /** + * When imageTransform != null, object is removed if new center is outside of transformed image bounds. 
+ * + * @param height Height of the output images + * @param width Width of the output images + * @param channels Number of channels for the output images + * @param gridH Grid/quantization size (along height dimension) - Y axis + * @param gridW Grid/quantization size (along width dimension) - X axis + * @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image + * @param nchw If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return + * NHWC format labels with array shape [minibatch, h, w, 4+C] + * @param imageTransform ImageTransform - used to transform image and coordinates + */ + public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw, + ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) { super(height, width, channels, null, null); this.gridW = gridW; this.gridH = gridH; + this.nchw = nchw; this.labelProvider = labelProvider; this.appendLabel = labelProvider != null; this.imageTransform = imageTransform; @@ -182,6 +207,10 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { exampleNum++; } + if(!nchw) { + outImg = outImg.permute(0, 2, 3, 1); //NCHW to NHWC + outLabel = outLabel.permute(0, 2, 3, 1); + } return new NDArrayRecordBatch(Arrays.asList(outImg, outLabel)); } @@ -256,6 +285,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { imageLoader = new NativeImageLoader(height, width, channels, imageTransform); } Image image = this.imageLoader.asImageMatrix(dataInputStream); + if(!nchw) + image.setImage(image.getImage().permute(0,2,3,1)); Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE); List ret = RecordConverter.toRecord(image.getImage()); @@ -264,6 +295,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { int nClasses = labels.size(); INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW); 
label(image, imageObjectsForPath, outLabel, 0); + if(!nchw) + outLabel = outLabel.permute(0,2,3,1); //NCHW to NHWC ret.add(new NDArrayWritable(outLabel)); } return ret; diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java index d8620096a..5e4598005 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java @@ -56,168 +56,179 @@ public class TestObjectDetectionRecordReader { @Test public void test() throws Exception { - ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider(); + for(boolean nchw : new boolean[]{true, false}) { + ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider(); - File f = testDir.newFolder(); - new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f); + File f = testDir.newFolder(); + new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f); - String path = new File(f, "000012.jpg").getParent(); + String path = new File(f, "000012.jpg").getParent(); - int h = 32; - int w = 32; - int c = 3; - int gW = 13; - int gH = 10; + int h = 32; + int w = 32; + int c = 3; + int gW = 13; + int gH = 10; - //Enforce consistent iteration order for tests - URI[] u = new FileSplit(new File(path)).locations(); - Arrays.sort(u); + //Enforce consistent iteration order for tests + URI[] u = new FileSplit(new File(path)).locations(); + Arrays.sort(u); - RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, lp); - rr.initialize(new CollectionInputSplit(u)); + RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, nchw, lp); + rr.initialize(new CollectionInputSplit(u)); - 
RecordReader imgRR = new ImageRecordReader(h, w, c); - imgRR.initialize(new CollectionInputSplit(u)); + RecordReader imgRR = new ImageRecordReader(h, w, c, nchw); + imgRR.initialize(new CollectionInputSplit(u)); - List labels = rr.getLabels(); - assertEquals(Arrays.asList("car", "cat"), labels); + List labels = rr.getLabels(); + assertEquals(Arrays.asList("car", "cat"), labels); - //000012.jpg - originally 500x333 - //000019.jpg - originally 500x375 - double[] origW = new double[]{500, 500}; - double[] origH = new double[]{333, 375}; - List> l = Arrays.asList( - Collections.singletonList(new ImageObject(156, 97, 351, 270, "car")), - Arrays.asList(new ImageObject(11, 113, 266, 259, "cat"), new ImageObject(231, 88, 483, 256, "cat")) - ); + //000012.jpg - originally 500x333 + //000019.jpg - originally 500x375 + double[] origW = new double[]{500, 500}; + double[] origH = new double[]{333, 375}; + List> l = Arrays.asList( + Collections.singletonList(new ImageObject(156, 97, 351, 270, "car")), + Arrays.asList(new ImageObject(11, 113, 266, 259, "cat"), new ImageObject(231, 88, 483, 256, "cat")) + ); - for (int idx = 0; idx < 2; idx++) { - assertTrue(rr.hasNext()); - List next = rr.next(); - List nextImgRR = imgRR.next(); + for (int idx = 0; idx < 2; idx++) { + assertTrue(rr.hasNext()); + List next = rr.next(); + List nextImgRR = imgRR.next(); - //Check features: - assertEquals(next.get(0), nextImgRR.get(0)); + //Check features: + assertEquals(next.get(0), nextImgRR.get(0)); - //Check labels - assertEquals(2, next.size()); - assertTrue(next.get(0) instanceof NDArrayWritable); - assertTrue(next.get(1) instanceof NDArrayWritable); + //Check labels + assertEquals(2, next.size()); + assertTrue(next.get(0) instanceof NDArrayWritable); + assertTrue(next.get(1) instanceof NDArrayWritable); - List objects = l.get(idx); + List objects = l.get(idx); - INDArray expLabels = Nd4j.create(1, 4 + 2, gH, gW); - for (ImageObject io : objects) { - double fracImageX1 = io.getX1() / 
origW[idx]; - double fracImageY1 = io.getY1() / origH[idx]; - double fracImageX2 = io.getX2() / origW[idx]; - double fracImageY2 = io.getY2() / origH[idx]; + INDArray expLabels = Nd4j.create(1, 4 + 2, gH, gW); + for (ImageObject io : objects) { + double fracImageX1 = io.getX1() / origW[idx]; + double fracImageY1 = io.getY1() / origH[idx]; + double fracImageX2 = io.getX2() / origW[idx]; + double fracImageY2 = io.getY2() / origH[idx]; - double x1C = (fracImageX1 + fracImageX2) / 2.0; - double y1C = (fracImageY1 + fracImageY2) / 2.0; + double x1C = (fracImageX1 + fracImageX2) / 2.0; + double y1C = (fracImageY1 + fracImageY2) / 2.0; - int labelGridX = (int) (x1C * gW); - int labelGridY = (int) (y1C * gH); + int labelGridX = (int) (x1C * gW); + int labelGridY = (int) (y1C * gH); - int labelIdx; - if (io.getLabel().equals("car")) { - labelIdx = 4; - } else { - labelIdx = 5; + int labelIdx; + if (io.getLabel().equals("car")) { + labelIdx = 4; + } else { + labelIdx = 5; + } + expLabels.putScalar(0, labelIdx, labelGridY, labelGridX, 1.0); + + expLabels.putScalar(0, 0, labelGridY, labelGridX, fracImageX1 * gW); + expLabels.putScalar(0, 1, labelGridY, labelGridX, fracImageY1 * gH); + expLabels.putScalar(0, 2, labelGridY, labelGridX, fracImageX2 * gW); + expLabels.putScalar(0, 3, labelGridY, labelGridX, fracImageY2 * gH); } - expLabels.putScalar(0, labelIdx, labelGridY, labelGridX, 1.0); - expLabels.putScalar(0, 0, labelGridY, labelGridX, fracImageX1 * gW); - expLabels.putScalar(0, 1, labelGridY, labelGridX, fracImageY1 * gH); - expLabels.putScalar(0, 2, labelGridY, labelGridX, fracImageX2 * gW); - expLabels.putScalar(0, 3, labelGridY, labelGridX, fracImageY2 * gH); + INDArray lArr = ((NDArrayWritable) next.get(1)).get(); + if(nchw) { + assertArrayEquals(new long[]{1, 4 + 2, gH, gW}, lArr.shape()); + } else { + assertArrayEquals(new long[]{1, gH, gW, 4 + 2}, lArr.shape()); + } + + if(!nchw) + expLabels = expLabels.permute(0,2,3,1); //NCHW to NHWC + + assertEquals(expLabels, 
lArr); } - INDArray lArr = ((NDArrayWritable) next.get(1)).get(); - assertArrayEquals(new long[]{1, 4 + 2, gH, gW}, lArr.shape()); - assertEquals(expLabels, lArr); - } + rr.reset(); + Record record = rr.nextRecord(); + RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) record.getMetaData(); + assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI())); + assertEquals(3, metadata.getOrigC()); + assertEquals((int) origH[0], metadata.getOrigH()); + assertEquals((int) origW[0], metadata.getOrigW()); - rr.reset(); - Record record = rr.nextRecord(); - RecordMetaDataImageURI metadata = (RecordMetaDataImageURI)record.getMetaData(); - assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI())); - assertEquals(3, metadata.getOrigC()); - assertEquals((int)origH[0], metadata.getOrigH()); - assertEquals((int)origW[0], metadata.getOrigW()); + List out = new ArrayList<>(); + List meta = new ArrayList<>(); + out.add(record); + meta.add(metadata); + record = rr.nextRecord(); + metadata = (RecordMetaDataImageURI) record.getMetaData(); + out.add(record); + meta.add(metadata); - List out = new ArrayList<>(); - List meta = new ArrayList<>(); - out.add(record); - meta.add(metadata); - record = rr.nextRecord(); - metadata = (RecordMetaDataImageURI)record.getMetaData(); - out.add(record); - meta.add(metadata); + List fromMeta = rr.loadFromMetaData(meta); + assertEquals(out, fromMeta); - List fromMeta = rr.loadFromMetaData(meta); - assertEquals(out, fromMeta); + // make sure we don't lose objects just by explicitly resizing + int i = 0; + int[] nonzeroCount = {5, 10}; - // make sure we don't lose objects just by explicitly resizing - int i = 0; - int[] nonzeroCount = {5, 10}; + ImageTransform transform = new ResizeImageTransform(37, 42); + RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, nchw, lp, transform); + rrTransform.initialize(new CollectionInputSplit(u)); + i = 0; + while (rrTransform.hasNext()) { + List next = 
rrTransform.next(); + assertEquals(37, transform.getCurrentImage().getWidth()); + assertEquals(42, transform.getCurrentImage().getHeight()); + INDArray labelArray = ((NDArrayWritable) next.get(1)).get(); + BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); + assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); + } - ImageTransform transform = new ResizeImageTransform(37, 42); - RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, lp, transform); - rrTransform.initialize(new CollectionInputSplit(u)); - i = 0; - while (rrTransform.hasNext()) { - List next = rrTransform.next(); - assertEquals(37, transform.getCurrentImage().getWidth()); - assertEquals(42, transform.getCurrentImage().getHeight()); - INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); - BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); - assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); - } + ImageTransform transform2 = new ResizeImageTransform(1024, 2048); + RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform2); + rrTransform2.initialize(new CollectionInputSplit(u)); + i = 0; + while (rrTransform2.hasNext()) { + List next = rrTransform2.next(); + assertEquals(1024, transform2.getCurrentImage().getWidth()); + assertEquals(2048, transform2.getCurrentImage().getHeight()); + INDArray labelArray = ((NDArrayWritable) next.get(1)).get(); + BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); + assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); + } + + //Make sure image flip does not break labels and are correct for new image size dimensions: + ImageTransform transform3 = new PipelineImageTransform( + new ResizeImageTransform(2048, 4096), + new FlipImageTransform(-1) + ); + RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform3); + rrTransform3.initialize(new CollectionInputSplit(u)); + i = 0; + 
while (rrTransform3.hasNext()) { + List next = rrTransform3.next(); + INDArray labelArray = ((NDArrayWritable) next.get(1)).get(); + BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); + assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); + } + + //Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception: + ImageTransform transform4 = new FlipImageTransform(-1); + RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, nchw, lp, transform4); + rrTransform4.initialize(new CollectionInputSplit(u)); + i = 0; + while (rrTransform4.hasNext()) { + List next = rrTransform4.next(); + + assertEquals((int) origW[i], transform4.getCurrentImage().getWidth()); + assertEquals((int) origH[i], transform4.getCurrentImage().getHeight()); + + INDArray labelArray = ((NDArrayWritable) next.get(1)).get(); + BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); + assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); + } - ImageTransform transform2 = new ResizeImageTransform(1024, 2048); - RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, transform2); - rrTransform2.initialize(new CollectionInputSplit(u)); - i = 0; - while (rrTransform2.hasNext()) { - List next = rrTransform2.next(); - assertEquals(1024, transform2.getCurrentImage().getWidth()); - assertEquals(2048, transform2.getCurrentImage().getHeight()); - INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); - BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); - assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); - } - - //Make sure image flip does not break labels and are correct for new image size dimensions: - ImageTransform transform3 = new PipelineImageTransform( - new ResizeImageTransform(2048, 4096), - new FlipImageTransform(-1) - ); - RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, 
transform3); - rrTransform3.initialize(new CollectionInputSplit(u)); - i = 0; - while (rrTransform3.hasNext()) { - List next = rrTransform3.next(); - INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); - BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); - assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); - } - - //Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception: - ImageTransform transform4 = new FlipImageTransform(-1); - RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, lp, transform4); - rrTransform4.initialize(new CollectionInputSplit(u)); - i = 0; - while (rrTransform4.hasNext()) { - List next = rrTransform4.next(); - - assertEquals((int) origW[i], transform4.getCurrentImage().getWidth()); - assertEquals((int) origH[i], transform4.getCurrentImage().getHeight()); - - INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); - BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); - assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 5646b6519..47c040c12 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -24,9 +24,7 @@ import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; -import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import 
org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.distribution.GaussianDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -36,6 +34,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -50,17 +50,28 @@ import java.io.File; import java.io.FileOutputStream; import java.io.InputStream; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertTrue; /** * @author Alex Black */ +@RunWith(Parameterized.class) public class YoloGradientCheckTests extends BaseDL4JTest { static { Nd4j.setDataType(DataType.DOUBLE); } + private CNN2DFormat format; + public YoloGradientCheckTests(CNN2DFormat format){ + this.format = format; + } + @Parameterized.Parameters(name = "{0}") + public static Object[] params(){ + return CNN2DFormat.values(); + } + @Rule public TemporaryFolder testDir = new TemporaryFolder(); @@ -97,8 +108,14 @@ public class YoloGradientCheckTests extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); - INDArray input = Nd4j.rand(new int[]{mb, depthIn, h, w}); - INDArray labels = yoloLabels(mb, c, h, w); + INDArray input, labels; + if(format == CNN2DFormat.NCHW){ + input = Nd4j.rand(DataType.DOUBLE, mb, depthIn, h, w); + labels = yoloLabels(mb, c, h, w); + } else { + input = Nd4j.rand(DataType.DOUBLE, mb, h, w, depthIn); + labels = yoloLabels(mb, c, h, w).permute(0,2,3,1); + } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) .dataType(DataType.DOUBLE) @@ -112,6 +129,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest { .layer(new 
Yolo2OutputLayer.Builder() .boundingBoxPriors(bbPrior) .build()) + .setInputType(InputType.convolutional(h, w, depthIn, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -120,7 +138,18 @@ public class YoloGradientCheckTests extends BaseDL4JTest { String msg = "testYoloOutputLayer() - minibatch = " + mb + ", w=" + w + ", h=" + h + ", l1=" + l1[i] + ", l2=" + l2[i]; System.out.println(msg); + INDArray out = net.output(input); + if(format == CNN2DFormat.NCHW){ + assertArrayEquals(new long[]{mb, yoloDepth, h, w}, out.shape()); + } else { + assertArrayEquals(new long[]{mb, h, w, yoloDepth}, out.shape()); + } + + net.fit(input, labels); + + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) + .minAbsoluteError(1e-6) .labels(labels).subset(true).maxPerParam(100)); assertTrue(msg, gradOK); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java index 24bda07f6..6ffb92978 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java @@ -21,6 +21,7 @@ import lombok.Getter; import lombok.Setter; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -80,6 +81,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer { @JsonDeserialize(using = BoundingBoxesDeserializer.class) private INDArray boundingBoxes; + private CNN2DFormat format = CNN2DFormat.NCHW; 
//Default for serialization of old formats + private Yolo2OutputLayer() { //No-arg constructor for Jackson JSON } @@ -119,7 +122,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer { @Override public void setNIn(InputType inputType, boolean override) { - //No op + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; + this.format = c.getFormat(); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java index eb5a4d19e..4d118c62b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.objdetect; import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.layers.IOutputLayer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -110,6 +111,12 @@ public class Yolo2OutputLayer extends AbstractLayer