diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java
index 1a53a05ac..38afd6adf 100644
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java
+++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java
@@ -49,7 +49,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point;
/**
* An image record reader for object detection.
*
- * Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w]
+ * Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w] (nchw) or [minibatch, h, w, 4+C] (nhwc)
* Where the image is quantized into h x w grid locations.
*
* Note that this matches the format required for Deeplearning4j's Yolo2OutputLayer
@@ -61,42 +61,67 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
private final int gridW;
private final int gridH;
private final ImageObjectLabelProvider labelProvider;
+ private final boolean nchw;
protected Image currentImage;
/**
+ * As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider)} but hardcoded
+ * to NCHW format
+ */
+ public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) {
+ this(height, width, channels, gridH, gridW, true, labelProvider);
+ }
+
+ /**
+ * Create an ObjectDetectionRecordReader with the specified output size, grid size, label array format (NCHW or NHWC), and label provider.
*
* @param height Height of the output images
* @param width Width of the output images
* @param channels Number of channels for the output images
* @param gridH Grid/quantization size (along height dimension) - Y axis
* @param gridW Grid/quantization size (along height dimension) - X axis
+ * @param nchw If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return
+ * NHWC format labels with array shape [minibatch, h, w, 4+C]
* @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
*/
- public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) {
+ public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw, ImageObjectLabelProvider labelProvider) {
super(height, width, channels, null, null);
this.gridW = gridW;
this.gridH = gridH;
+ this.nchw = nchw;
this.labelProvider = labelProvider;
this.appendLabel = labelProvider != null;
}
/**
- * When imageTransform != null, object is removed if new center is outside of transformed image bounds.
- *
- * @param height Height of the output images
- * @param width Width of the output images
- * @param channels Number of channels for the output images
- * @param gridH Grid/quantization size (along height dimension) - Y axis
- * @param gridW Grid/quantization size (along height dimension) - X axis
- * @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
- * @param imageTransform ImageTransform - used to transform image and coordinates
+ * As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider, ImageTransform)}
+ * but hardcoded to NCHW format
*/
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW,
- ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
+ ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
+ this(height, width, channels, gridH, gridW, true, labelProvider, imageTransform);
+ }
+
+ /**
+ * When imageTransform != null, objects are removed if their new center falls outside the transformed image bounds.
+ *
+ * @param height Height of the output images
+ * @param width Width of the output images
+ * @param channels Number of channels for the output images
+ * @param gridH Grid/quantization size (along height dimension) - Y axis
+ * @param gridW Grid/quantization size (along width dimension) - X axis
+ * @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
+ * @param nchw If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return
+ * NHWC format labels with array shape [minibatch, h, w, 4+C]
+ * @param imageTransform ImageTransform - used to transform image and coordinates
+ */
+ public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw,
+ ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
super(height, width, channels, null, null);
this.gridW = gridW;
this.gridH = gridH;
+ this.nchw = nchw;
this.labelProvider = labelProvider;
this.appendLabel = labelProvider != null;
this.imageTransform = imageTransform;
@@ -182,6 +207,10 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
exampleNum++;
}
+ if(!nchw) {
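+ //Both arrays are built as NCHW above; a single permute(0, 2, 3, 1) moves the channel axis last, giving NHWC views of the features and labels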
+ outImg = outImg.permute(0, 2, 3, 1); //NCHW to NHWC
+ outLabel = outLabel.permute(0, 2, 3, 1);
+ }
return new NDArrayRecordBatch(Arrays.asList(outImg, outLabel));
}
@@ -256,6 +285,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
}
Image image = this.imageLoader.asImageMatrix(dataInputStream);
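+ //The image loader returns NCHW features; convert to NHWC here when channels-last output was requested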
+ if(!nchw)
+ image.setImage(image.getImage().permute(0,2,3,1));
Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
List<Writable> ret = RecordConverter.toRecord(image.getImage());
@@ -264,6 +295,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
int nClasses = labels.size();
INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW);
label(image, imageObjectsForPath, outLabel, 0);
+ if(!nchw)
+ outLabel = outLabel.permute(0,2,3,1); //NCHW to NHWC
ret.add(new NDArrayWritable(outLabel));
}
return ret;
diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java
index d8620096a..5e4598005 100644
--- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java
+++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java
@@ -56,168 +56,179 @@ public class TestObjectDetectionRecordReader {
@Test
public void test() throws Exception {
- ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
+ for(boolean nchw : new boolean[]{true, false}) {
+ ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
- File f = testDir.newFolder();
- new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
+ File f = testDir.newFolder();
+ new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
- String path = new File(f, "000012.jpg").getParent();
+ String path = new File(f, "000012.jpg").getParent();
- int h = 32;
- int w = 32;
- int c = 3;
- int gW = 13;
- int gH = 10;
+ int h = 32;
+ int w = 32;
+ int c = 3;
+ int gW = 13;
+ int gH = 10;
- //Enforce consistent iteration order for tests
- URI[] u = new FileSplit(new File(path)).locations();
- Arrays.sort(u);
+ //Enforce consistent iteration order for tests
+ URI[] u = new FileSplit(new File(path)).locations();
+ Arrays.sort(u);
- RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, lp);
- rr.initialize(new CollectionInputSplit(u));
+ RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, nchw, lp);
+ rr.initialize(new CollectionInputSplit(u));
- RecordReader imgRR = new ImageRecordReader(h, w, c);
- imgRR.initialize(new CollectionInputSplit(u));
+ RecordReader imgRR = new ImageRecordReader(h, w, c, nchw);
+ imgRR.initialize(new CollectionInputSplit(u));
- List<String> labels = rr.getLabels();
- assertEquals(Arrays.asList("car", "cat"), labels);
+ List<String> labels = rr.getLabels();
+ assertEquals(Arrays.asList("car", "cat"), labels);
- //000012.jpg - originally 500x333
- //000019.jpg - originally 500x375
- double[] origW = new double[]{500, 500};
- double[] origH = new double[]{333, 375};
- List<List<ImageObject>> l = Arrays.asList(
- Collections.singletonList(new ImageObject(156, 97, 351, 270, "car")),
- Arrays.asList(new ImageObject(11, 113, 266, 259, "cat"), new ImageObject(231, 88, 483, 256, "cat"))
- );
+ //000012.jpg - originally 500x333
+ //000019.jpg - originally 500x375
+ double[] origW = new double[]{500, 500};
+ double[] origH = new double[]{333, 375};
+ List<List<ImageObject>> l = Arrays.asList(
+ Collections.singletonList(new ImageObject(156, 97, 351, 270, "car")),
+ Arrays.asList(new ImageObject(11, 113, 266, 259, "cat"), new ImageObject(231, 88, 483, 256, "cat"))
+ );
- for (int idx = 0; idx < 2; idx++) {
- assertTrue(rr.hasNext());
- List<Writable> next = rr.next();
- List<Writable> nextImgRR = imgRR.next();
+ for (int idx = 0; idx < 2; idx++) {
+ assertTrue(rr.hasNext());
+ List<Writable> next = rr.next();
+ List<Writable> nextImgRR = imgRR.next();
- //Check features:
- assertEquals(next.get(0), nextImgRR.get(0));
+ //Check features:
+ assertEquals(next.get(0), nextImgRR.get(0));
- //Check labels
- assertEquals(2, next.size());
- assertTrue(next.get(0) instanceof NDArrayWritable);
- assertTrue(next.get(1) instanceof NDArrayWritable);
+ //Check labels
+ assertEquals(2, next.size());
+ assertTrue(next.get(0) instanceof NDArrayWritable);
+ assertTrue(next.get(1) instanceof NDArrayWritable);
- List<ImageObject> objects = l.get(idx);
+ List<ImageObject> objects = l.get(idx);
- INDArray expLabels = Nd4j.create(1, 4 + 2, gH, gW);
- for (ImageObject io : objects) {
- double fracImageX1 = io.getX1() / origW[idx];
- double fracImageY1 = io.getY1() / origH[idx];
- double fracImageX2 = io.getX2() / origW[idx];
- double fracImageY2 = io.getY2() / origH[idx];
+ INDArray expLabels = Nd4j.create(1, 4 + 2, gH, gW);
+ for (ImageObject io : objects) {
+ double fracImageX1 = io.getX1() / origW[idx];
+ double fracImageY1 = io.getY1() / origH[idx];
+ double fracImageX2 = io.getX2() / origW[idx];
+ double fracImageY2 = io.getY2() / origH[idx];
- double x1C = (fracImageX1 + fracImageX2) / 2.0;
- double y1C = (fracImageY1 + fracImageY2) / 2.0;
+ double x1C = (fracImageX1 + fracImageX2) / 2.0;
+ double y1C = (fracImageY1 + fracImageY2) / 2.0;
- int labelGridX = (int) (x1C * gW);
- int labelGridY = (int) (y1C * gH);
+ int labelGridX = (int) (x1C * gW);
+ int labelGridY = (int) (y1C * gH);
- int labelIdx;
- if (io.getLabel().equals("car")) {
- labelIdx = 4;
- } else {
- labelIdx = 5;
+ int labelIdx;
+ if (io.getLabel().equals("car")) {
+ labelIdx = 4;
+ } else {
+ labelIdx = 5;
+ }
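+ //Expected label layout (NCHW): channels 0..3 hold the box corners (x1, y1, x2, y2) in grid units; channel 4 + classIdx is a one-hot class indicator at the object's grid cell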
+ expLabels.putScalar(0, labelIdx, labelGridY, labelGridX, 1.0);
+
+ expLabels.putScalar(0, 0, labelGridY, labelGridX, fracImageX1 * gW);
+ expLabels.putScalar(0, 1, labelGridY, labelGridX, fracImageY1 * gH);
+ expLabels.putScalar(0, 2, labelGridY, labelGridX, fracImageX2 * gW);
+ expLabels.putScalar(0, 3, labelGridY, labelGridX, fracImageY2 * gH);
}
- expLabels.putScalar(0, labelIdx, labelGridY, labelGridX, 1.0);
- expLabels.putScalar(0, 0, labelGridY, labelGridX, fracImageX1 * gW);
- expLabels.putScalar(0, 1, labelGridY, labelGridX, fracImageY1 * gH);
- expLabels.putScalar(0, 2, labelGridY, labelGridX, fracImageX2 * gW);
- expLabels.putScalar(0, 3, labelGridY, labelGridX, fracImageY2 * gH);
+ INDArray lArr = ((NDArrayWritable) next.get(1)).get();
+ if(nchw) {
+ assertArrayEquals(new long[]{1, 4 + 2, gH, gW}, lArr.shape());
+ } else {
+ assertArrayEquals(new long[]{1, gH, gW, 4 + 2}, lArr.shape());
+ }
+
+ if(!nchw)
+ expLabels = expLabels.permute(0,2,3,1); //NCHW to NHWC
+
+ assertEquals(expLabels, lArr);
}
- INDArray lArr = ((NDArrayWritable) next.get(1)).get();
- assertArrayEquals(new long[]{1, 4 + 2, gH, gW}, lArr.shape());
- assertEquals(expLabels, lArr);
- }
+ rr.reset();
+ Record record = rr.nextRecord();
+ RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) record.getMetaData();
+ assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI()));
+ assertEquals(3, metadata.getOrigC());
+ assertEquals((int) origH[0], metadata.getOrigH());
+ assertEquals((int) origW[0], metadata.getOrigW());
- rr.reset();
- Record record = rr.nextRecord();
- RecordMetaDataImageURI metadata = (RecordMetaDataImageURI)record.getMetaData();
- assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI()));
- assertEquals(3, metadata.getOrigC());
- assertEquals((int)origH[0], metadata.getOrigH());
- assertEquals((int)origW[0], metadata.getOrigW());
+ List<Record> out = new ArrayList<>();
+ List<RecordMetaData> meta = new ArrayList<>();
+ out.add(record);
+ meta.add(metadata);
+ record = rr.nextRecord();
+ metadata = (RecordMetaDataImageURI) record.getMetaData();
+ out.add(record);
+ meta.add(metadata);
- List<Record> out = new ArrayList<>();
- List<RecordMetaData> meta = new ArrayList<>();
- out.add(record);
- meta.add(metadata);
- record = rr.nextRecord();
- metadata = (RecordMetaDataImageURI)record.getMetaData();
- out.add(record);
- meta.add(metadata);
+ List<Record> fromMeta = rr.loadFromMetaData(meta);
+ assertEquals(out, fromMeta);
- List<Record> fromMeta = rr.loadFromMetaData(meta);
- assertEquals(out, fromMeta);
+ // make sure we don't lose objects just by explicitly resizing
+ int i = 0;
+ int[] nonzeroCount = {5, 10};
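+ //Each object contributes 5 nonzero label entries (4 box coordinates + 1 class flag): image 0 has 1 object, image 1 has 2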
- // make sure we don't lose objects just by explicitly resizing
- int i = 0;
- int[] nonzeroCount = {5, 10};
+ ImageTransform transform = new ResizeImageTransform(37, 42);
+ RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, nchw, lp, transform);
+ rrTransform.initialize(new CollectionInputSplit(u));
+ i = 0;
+ while (rrTransform.hasNext()) {
+ List<Writable> next = rrTransform.next();
+ assertEquals(37, transform.getCurrentImage().getWidth());
+ assertEquals(42, transform.getCurrentImage().getHeight());
+ INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
+ BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
+ assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
+ }
- ImageTransform transform = new ResizeImageTransform(37, 42);
- RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, lp, transform);
- rrTransform.initialize(new CollectionInputSplit(u));
- i = 0;
- while (rrTransform.hasNext()) {
- List<Writable> next = rrTransform.next();
- assertEquals(37, transform.getCurrentImage().getWidth());
- assertEquals(42, transform.getCurrentImage().getHeight());
- INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
- BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
- assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
- }
+ ImageTransform transform2 = new ResizeImageTransform(1024, 2048);
+ RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform2);
+ rrTransform2.initialize(new CollectionInputSplit(u));
+ i = 0;
+ while (rrTransform2.hasNext()) {
+ List<Writable> next = rrTransform2.next();
+ assertEquals(1024, transform2.getCurrentImage().getWidth());
+ assertEquals(2048, transform2.getCurrentImage().getHeight());
+ INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
+ BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
+ assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
+ }
+
+ //Make sure the image flip does not break the labels, and that they are correct for the new image dimensions:
+ ImageTransform transform3 = new PipelineImageTransform(
+ new ResizeImageTransform(2048, 4096),
+ new FlipImageTransform(-1)
+ );
+ RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform3);
+ rrTransform3.initialize(new CollectionInputSplit(u));
+ i = 0;
+ while (rrTransform3.hasNext()) {
+ List<Writable> next = rrTransform3.next();
+ INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
+ BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
+ assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
+ }
+
+ //Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception:
+ ImageTransform transform4 = new FlipImageTransform(-1);
+ RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, nchw, lp, transform4);
+ rrTransform4.initialize(new CollectionInputSplit(u));
+ i = 0;
+ while (rrTransform4.hasNext()) {
+ List<Writable> next = rrTransform4.next();
+
+ assertEquals((int) origW[i], transform4.getCurrentImage().getWidth());
+ assertEquals((int) origH[i], transform4.getCurrentImage().getHeight());
+
+ INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
+ BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
+ assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
+ }
- ImageTransform transform2 = new ResizeImageTransform(1024, 2048);
- RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, transform2);
- rrTransform2.initialize(new CollectionInputSplit(u));
- i = 0;
- while (rrTransform2.hasNext()) {
- List<Writable> next = rrTransform2.next();
- assertEquals(1024, transform2.getCurrentImage().getWidth());
- assertEquals(2048, transform2.getCurrentImage().getHeight());
- INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
- BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
- assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
- }
-
- //Make sure image flip does not break labels and are correct for new image size dimensions:
- ImageTransform transform3 = new PipelineImageTransform(
- new ResizeImageTransform(2048, 4096),
- new FlipImageTransform(-1)
- );
- RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, transform3);
- rrTransform3.initialize(new CollectionInputSplit(u));
- i = 0;
- while (rrTransform3.hasNext()) {
- List<Writable> next = rrTransform3.next();
- INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
- BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
- assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
- }
-
- //Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception:
- ImageTransform transform4 = new FlipImageTransform(-1);
- RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, lp, transform4);
- rrTransform4.initialize(new CollectionInputSplit(u));
- i = 0;
- while (rrTransform4.hasNext()) {
- List<Writable> next = rrTransform4.next();
-
- assertEquals((int) origW[i], transform4.getCurrentImage().getWidth());
- assertEquals((int) origH[i], transform4.getCurrentImage().getHeight());
-
- INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
- BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
- assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java
index 5646b6519..47c040c12 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java
@@ -24,9 +24,7 @@ import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
-import org.deeplearning4j.nn.conf.ConvolutionMode;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
@@ -36,6 +34,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -50,17 +50,28 @@ import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertTrue;
/**
* @author Alex Black
*/
+@RunWith(Parameterized.class)
public class YoloGradientCheckTests extends BaseDL4JTest {
static {
Nd4j.setDataType(DataType.DOUBLE);
}
+ private CNN2DFormat format;
+ public YoloGradientCheckTests(CNN2DFormat format){
+ this.format = format;
+ }
+ @Parameterized.Parameters(name = "{0}")
+ public static Object[] params(){
+ return CNN2DFormat.values();
+ }
+
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@@ -97,8 +108,14 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345);
- INDArray input = Nd4j.rand(new int[]{mb, depthIn, h, w});
- INDArray labels = yoloLabels(mb, c, h, w);
+ INDArray input, labels;
+ if(format == CNN2DFormat.NCHW){
+ input = Nd4j.rand(DataType.DOUBLE, mb, depthIn, h, w);
+ labels = yoloLabels(mb, c, h, w);
+ } else {
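+ //NHWC: input is channels-last; yoloLabels(...) produces NCHW labels, so permute to [mb, h, w, 4+C]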
+ input = Nd4j.rand(DataType.DOUBLE, mb, h, w, depthIn);
+ labels = yoloLabels(mb, c, h, w).permute(0,2,3,1);
+ }
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
@@ -112,6 +129,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
.layer(new Yolo2OutputLayer.Builder()
.boundingBoxPriors(bbPrior)
.build())
+ .setInputType(InputType.convolutional(h, w, depthIn, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@@ -120,7 +138,18 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
String msg = "testYoloOutputLayer() - minibatch = " + mb + ", w=" + w + ", h=" + h + ", l1=" + l1[i] + ", l2=" + l2[i];
System.out.println(msg);
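+ //Sanity checks before the gradient check: the output shape must match the configured data format, and fit() must run without error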
+ INDArray out = net.output(input);
+ if(format == CNN2DFormat.NCHW){
+ assertArrayEquals(new long[]{mb, yoloDepth, h, w}, out.shape());
+ } else {
+ assertArrayEquals(new long[]{mb, h, w, yoloDepth}, out.shape());
+ }
+
+ net.fit(input, labels);
+
+
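+ //minAbsoluteError: parameters whose numerical and analytic gradients differ by less than this pass even if the relative error is large (handles near-zero gradients)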
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+ .minAbsoluteError(1e-6)
.labels(labels).subset(true).maxPerParam(100));
assertTrue(msg, gradOK);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java
index 24bda07f6..6ffb92978 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java
@@ -21,6 +21,7 @@ import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -80,6 +81,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer {
@JsonDeserialize(using = BoundingBoxesDeserializer.class)
private INDArray boundingBoxes;
+ private CNN2DFormat format = CNN2DFormat.NCHW; //Default for serialization of old formats
+
private Yolo2OutputLayer() {
//No-arg constructor for Jackson JSON
}
@@ -119,7 +122,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer {
@Override
public void setNIn(InputType inputType, boolean override) {
- //No op
+ InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
+ this.format = c.getFormat();
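+ //Record the activation data format (NCHW vs NHWC) from the input type so the layer implementation can interpret its input correctly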
}
@Override
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java
index eb5a4d19e..4d118c62b 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.objdetect;
import lombok.*;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
@@ -110,6 +111,12 @@ public class Yolo2OutputLayer extends AbstractLayer