DL4J/DataVec: Fix Yolo2OutputLayer and ObjectDetectionRecordReader support for NHWC data format (#483)

* Fix Yolo2OutputLayer for NHWC data format

Signed-off-by: Alex Black <blacka101@gmail.com>

* ObjectDetectionRecordReader NHWC support

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-06-05 11:49:02 +10:00 committed by GitHub
parent 45ebd4899c
commit ee3e059b12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 260 additions and 158 deletions

View File

@ -49,7 +49,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point;
/** /**
* An image record reader for object detection. * An image record reader for object detection.
* <p> * <p>
* Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w] * Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w] (nchw) or [minibatch, h, w, 4+C] (nhwc)
* Where the image is quantized into h x w grid locations. * Where the image is quantized into h x w grid locations.
* <p> * <p>
* Note that this matches the format required for Deeplearning4j's Yolo2OutputLayer * Note that this matches the format required for Deeplearning4j's Yolo2OutputLayer
@ -61,42 +61,67 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
private final int gridW; private final int gridW;
private final int gridH; private final int gridH;
private final ImageObjectLabelProvider labelProvider; private final ImageObjectLabelProvider labelProvider;
private final boolean nchw;
protected Image currentImage; protected Image currentImage;
/** /**
* As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider)} but hardcoded
* to NCHW format
*/
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) {
this(height, width, channels, gridH, gridW, true, labelProvider);
}
/**
* Create ObjectDetectionRecordReader with
* *
* @param height Height of the output images * @param height Height of the output images
* @param width Width of the output images * @param width Width of the output images
* @param channels Number of channels for the output images * @param channels Number of channels for the output images
* @param gridH Grid/quantization size (along height dimension) - Y axis * @param gridH Grid/quantization size (along height dimension) - Y axis
* @param gridW Grid/quantization size (along height dimension) - X axis * @param gridW Grid/quantization size (along height dimension) - X axis
* @param nchw If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return
* NHWC format labels with array shape [minibatch, h, w, 4+C]
* @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image * @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
*/ */
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) { public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw, ImageObjectLabelProvider labelProvider) {
super(height, width, channels, null, null); super(height, width, channels, null, null);
this.gridW = gridW; this.gridW = gridW;
this.gridH = gridH; this.gridH = gridH;
this.nchw = nchw;
this.labelProvider = labelProvider; this.labelProvider = labelProvider;
this.appendLabel = labelProvider != null; this.appendLabel = labelProvider != null;
} }
/** /**
* When imageTransform != null, object is removed if new center is outside of transformed image bounds. * As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider, ImageTransform)}
* * but hardcoded to NCHW format
* @param height Height of the output images
* @param width Width of the output images
* @param channels Number of channels for the output images
* @param gridH Grid/quantization size (along height dimension) - Y axis
* @param gridW Grid/quantization size (along height dimension) - X axis
* @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
* @param imageTransform ImageTransform - used to transform image and coordinates
*/ */
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW,
ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) { ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
this(height, width, channels, gridH, gridW, true, labelProvider, imageTransform);
}
/**
* When imageTransform != null, object is removed if new center is outside of transformed image bounds.
*
* @param height Height of the output images
* @param width Width of the output images
* @param channels Number of channels for the output images
* @param gridH Grid/quantization size (along height dimension) - Y axis
* @param gridW Grid/quantization size (along height dimension) - X axis
* @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
* @param nchw If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return
* NHWC format labels with array shape [minibatch, h, w, 4+C]
* @param imageTransform ImageTransform - used to transform image and coordinates
*/
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw,
ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
super(height, width, channels, null, null); super(height, width, channels, null, null);
this.gridW = gridW; this.gridW = gridW;
this.gridH = gridH; this.gridH = gridH;
this.nchw = nchw;
this.labelProvider = labelProvider; this.labelProvider = labelProvider;
this.appendLabel = labelProvider != null; this.appendLabel = labelProvider != null;
this.imageTransform = imageTransform; this.imageTransform = imageTransform;
@ -182,6 +207,10 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
exampleNum++; exampleNum++;
} }
if(!nchw) {
outImg = outImg.permute(0, 2, 3, 1); //NCHW to NHWC
outLabel = outLabel.permute(0, 2, 3, 1);
}
return new NDArrayRecordBatch(Arrays.asList(outImg, outLabel)); return new NDArrayRecordBatch(Arrays.asList(outImg, outLabel));
} }
@ -256,6 +285,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform); imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
} }
Image image = this.imageLoader.asImageMatrix(dataInputStream); Image image = this.imageLoader.asImageMatrix(dataInputStream);
if(!nchw)
image.setImage(image.getImage().permute(0,2,3,1));
Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE); Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
List<Writable> ret = RecordConverter.toRecord(image.getImage()); List<Writable> ret = RecordConverter.toRecord(image.getImage());
@ -264,6 +295,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
int nClasses = labels.size(); int nClasses = labels.size();
INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW); INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW);
label(image, imageObjectsForPath, outLabel, 0); label(image, imageObjectsForPath, outLabel, 0);
if(!nchw)
outLabel = outLabel.permute(0,2,3,1); //NCHW to NHWC
ret.add(new NDArrayWritable(outLabel)); ret.add(new NDArrayWritable(outLabel));
} }
return ret; return ret;

View File

@ -56,168 +56,179 @@ public class TestObjectDetectionRecordReader {
@Test @Test
public void test() throws Exception { public void test() throws Exception {
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider(); for(boolean nchw : new boolean[]{true, false}) {
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
File f = testDir.newFolder(); File f = testDir.newFolder();
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f); new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
String path = new File(f, "000012.jpg").getParent(); String path = new File(f, "000012.jpg").getParent();
int h = 32; int h = 32;
int w = 32; int w = 32;
int c = 3; int c = 3;
int gW = 13; int gW = 13;
int gH = 10; int gH = 10;
//Enforce consistent iteration order for tests //Enforce consistent iteration order for tests
URI[] u = new FileSplit(new File(path)).locations(); URI[] u = new FileSplit(new File(path)).locations();
Arrays.sort(u); Arrays.sort(u);
RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, lp); RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, nchw, lp);
rr.initialize(new CollectionInputSplit(u)); rr.initialize(new CollectionInputSplit(u));
RecordReader imgRR = new ImageRecordReader(h, w, c); RecordReader imgRR = new ImageRecordReader(h, w, c, nchw);
imgRR.initialize(new CollectionInputSplit(u)); imgRR.initialize(new CollectionInputSplit(u));
List<String> labels = rr.getLabels(); List<String> labels = rr.getLabels();
assertEquals(Arrays.asList("car", "cat"), labels); assertEquals(Arrays.asList("car", "cat"), labels);
//000012.jpg - originally 500x333 //000012.jpg - originally 500x333
//000019.jpg - originally 500x375 //000019.jpg - originally 500x375
double[] origW = new double[]{500, 500}; double[] origW = new double[]{500, 500};
double[] origH = new double[]{333, 375}; double[] origH = new double[]{333, 375};
List<List<ImageObject>> l = Arrays.asList( List<List<ImageObject>> l = Arrays.asList(
Collections.singletonList(new ImageObject(156, 97, 351, 270, "car")), Collections.singletonList(new ImageObject(156, 97, 351, 270, "car")),
Arrays.asList(new ImageObject(11, 113, 266, 259, "cat"), new ImageObject(231, 88, 483, 256, "cat")) Arrays.asList(new ImageObject(11, 113, 266, 259, "cat"), new ImageObject(231, 88, 483, 256, "cat"))
); );
for (int idx = 0; idx < 2; idx++) { for (int idx = 0; idx < 2; idx++) {
assertTrue(rr.hasNext()); assertTrue(rr.hasNext());
List<Writable> next = rr.next(); List<Writable> next = rr.next();
List<Writable> nextImgRR = imgRR.next(); List<Writable> nextImgRR = imgRR.next();
//Check features: //Check features:
assertEquals(next.get(0), nextImgRR.get(0)); assertEquals(next.get(0), nextImgRR.get(0));
//Check labels //Check labels
assertEquals(2, next.size()); assertEquals(2, next.size());
assertTrue(next.get(0) instanceof NDArrayWritable); assertTrue(next.get(0) instanceof NDArrayWritable);
assertTrue(next.get(1) instanceof NDArrayWritable); assertTrue(next.get(1) instanceof NDArrayWritable);
List<ImageObject> objects = l.get(idx); List<ImageObject> objects = l.get(idx);
INDArray expLabels = Nd4j.create(1, 4 + 2, gH, gW); INDArray expLabels = Nd4j.create(1, 4 + 2, gH, gW);
for (ImageObject io : objects) { for (ImageObject io : objects) {
double fracImageX1 = io.getX1() / origW[idx]; double fracImageX1 = io.getX1() / origW[idx];
double fracImageY1 = io.getY1() / origH[idx]; double fracImageY1 = io.getY1() / origH[idx];
double fracImageX2 = io.getX2() / origW[idx]; double fracImageX2 = io.getX2() / origW[idx];
double fracImageY2 = io.getY2() / origH[idx]; double fracImageY2 = io.getY2() / origH[idx];
double x1C = (fracImageX1 + fracImageX2) / 2.0; double x1C = (fracImageX1 + fracImageX2) / 2.0;
double y1C = (fracImageY1 + fracImageY2) / 2.0; double y1C = (fracImageY1 + fracImageY2) / 2.0;
int labelGridX = (int) (x1C * gW); int labelGridX = (int) (x1C * gW);
int labelGridY = (int) (y1C * gH); int labelGridY = (int) (y1C * gH);
int labelIdx; int labelIdx;
if (io.getLabel().equals("car")) { if (io.getLabel().equals("car")) {
labelIdx = 4; labelIdx = 4;
} else { } else {
labelIdx = 5; labelIdx = 5;
}
expLabels.putScalar(0, labelIdx, labelGridY, labelGridX, 1.0);
expLabels.putScalar(0, 0, labelGridY, labelGridX, fracImageX1 * gW);
expLabels.putScalar(0, 1, labelGridY, labelGridX, fracImageY1 * gH);
expLabels.putScalar(0, 2, labelGridY, labelGridX, fracImageX2 * gW);
expLabels.putScalar(0, 3, labelGridY, labelGridX, fracImageY2 * gH);
} }
expLabels.putScalar(0, labelIdx, labelGridY, labelGridX, 1.0);
expLabels.putScalar(0, 0, labelGridY, labelGridX, fracImageX1 * gW); INDArray lArr = ((NDArrayWritable) next.get(1)).get();
expLabels.putScalar(0, 1, labelGridY, labelGridX, fracImageY1 * gH); if(nchw) {
expLabels.putScalar(0, 2, labelGridY, labelGridX, fracImageX2 * gW); assertArrayEquals(new long[]{1, 4 + 2, gH, gW}, lArr.shape());
expLabels.putScalar(0, 3, labelGridY, labelGridX, fracImageY2 * gH); } else {
assertArrayEquals(new long[]{1, gH, gW, 4 + 2}, lArr.shape());
}
if(!nchw)
expLabels = expLabels.permute(0,2,3,1); //NCHW to NHWC
assertEquals(expLabels, lArr);
} }
INDArray lArr = ((NDArrayWritable) next.get(1)).get(); rr.reset();
assertArrayEquals(new long[]{1, 4 + 2, gH, gW}, lArr.shape()); Record record = rr.nextRecord();
assertEquals(expLabels, lArr); RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) record.getMetaData();
} assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI()));
assertEquals(3, metadata.getOrigC());
assertEquals((int) origH[0], metadata.getOrigH());
assertEquals((int) origW[0], metadata.getOrigW());
rr.reset(); List<Record> out = new ArrayList<>();
Record record = rr.nextRecord(); List<RecordMetaData> meta = new ArrayList<>();
RecordMetaDataImageURI metadata = (RecordMetaDataImageURI)record.getMetaData(); out.add(record);
assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI())); meta.add(metadata);
assertEquals(3, metadata.getOrigC()); record = rr.nextRecord();
assertEquals((int)origH[0], metadata.getOrigH()); metadata = (RecordMetaDataImageURI) record.getMetaData();
assertEquals((int)origW[0], metadata.getOrigW()); out.add(record);
meta.add(metadata);
List<Record> out = new ArrayList<>(); List<Record> fromMeta = rr.loadFromMetaData(meta);
List<RecordMetaData> meta = new ArrayList<>(); assertEquals(out, fromMeta);
out.add(record);
meta.add(metadata);
record = rr.nextRecord();
metadata = (RecordMetaDataImageURI)record.getMetaData();
out.add(record);
meta.add(metadata);
List<Record> fromMeta = rr.loadFromMetaData(meta); // make sure we don't lose objects just by explicitly resizing
assertEquals(out, fromMeta); int i = 0;
int[] nonzeroCount = {5, 10};
// make sure we don't lose objects just by explicitly resizing ImageTransform transform = new ResizeImageTransform(37, 42);
int i = 0; RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, nchw, lp, transform);
int[] nonzeroCount = {5, 10}; rrTransform.initialize(new CollectionInputSplit(u));
i = 0;
while (rrTransform.hasNext()) {
List<Writable> next = rrTransform.next();
assertEquals(37, transform.getCurrentImage().getWidth());
assertEquals(42, transform.getCurrentImage().getHeight());
INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
}
ImageTransform transform = new ResizeImageTransform(37, 42); ImageTransform transform2 = new ResizeImageTransform(1024, 2048);
RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, lp, transform); RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform2);
rrTransform.initialize(new CollectionInputSplit(u)); rrTransform2.initialize(new CollectionInputSplit(u));
i = 0; i = 0;
while (rrTransform.hasNext()) { while (rrTransform2.hasNext()) {
List<Writable> next = rrTransform.next(); List<Writable> next = rrTransform2.next();
assertEquals(37, transform.getCurrentImage().getWidth()); assertEquals(1024, transform2.getCurrentImage().getWidth());
assertEquals(42, transform.getCurrentImage().getHeight()); assertEquals(2048, transform2.getCurrentImage().getHeight());
INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
} }
ImageTransform transform2 = new ResizeImageTransform(1024, 2048); //Make sure image flip does not break labels and are correct for new image size dimensions:
RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, transform2); ImageTransform transform3 = new PipelineImageTransform(
rrTransform2.initialize(new CollectionInputSplit(u)); new ResizeImageTransform(2048, 4096),
i = 0; new FlipImageTransform(-1)
while (rrTransform2.hasNext()) { );
List<Writable> next = rrTransform2.next(); RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform3);
assertEquals(1024, transform2.getCurrentImage().getWidth()); rrTransform3.initialize(new CollectionInputSplit(u));
assertEquals(2048, transform2.getCurrentImage().getHeight()); i = 0;
INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); while (rrTransform3.hasNext()) {
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); List<Writable> next = rrTransform3.next();
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
} BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
}
//Make sure image flip does not break labels and are correct for new image size dimensions: //Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception:
ImageTransform transform3 = new PipelineImageTransform( ImageTransform transform4 = new FlipImageTransform(-1);
new ResizeImageTransform(2048, 4096), RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, nchw, lp, transform4);
new FlipImageTransform(-1) rrTransform4.initialize(new CollectionInputSplit(u));
); i = 0;
RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, transform3); while (rrTransform4.hasNext()) {
rrTransform3.initialize(new CollectionInputSplit(u)); List<Writable> next = rrTransform4.next();
i = 0;
while (rrTransform3.hasNext()) {
List<Writable> next = rrTransform3.next();
INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
}
//Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception: assertEquals((int) origW[i], transform4.getCurrentImage().getWidth());
ImageTransform transform4 = new FlipImageTransform(-1); assertEquals((int) origH[i], transform4.getCurrentImage().getHeight());
RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, lp, transform4);
rrTransform4.initialize(new CollectionInputSplit(u));
i = 0;
while (rrTransform4.hasNext()) {
List<Writable> next = rrTransform4.next();
assertEquals((int) origW[i], transform4.getCurrentImage().getWidth()); INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
assertEquals((int) origH[i], transform4.getCurrentImage().getHeight()); BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
}
INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
} }
} }

View File

@ -24,9 +24,7 @@ import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution; import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
@ -36,6 +34,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -50,17 +50,28 @@ import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.InputStream; import java.io.InputStream;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
/** /**
* @author Alex Black * @author Alex Black
*/ */
@RunWith(Parameterized.class)
public class YoloGradientCheckTests extends BaseDL4JTest { public class YoloGradientCheckTests extends BaseDL4JTest {
static { static {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
private CNN2DFormat format;
public YoloGradientCheckTests(CNN2DFormat format){
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return CNN2DFormat.values();
}
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();
@ -97,8 +108,14 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(new int[]{mb, depthIn, h, w}); INDArray input, labels;
INDArray labels = yoloLabels(mb, c, h, w); if(format == CNN2DFormat.NCHW){
input = Nd4j.rand(DataType.DOUBLE, mb, depthIn, h, w);
labels = yoloLabels(mb, c, h, w);
} else {
input = Nd4j.rand(DataType.DOUBLE, mb, h, w, depthIn);
labels = yoloLabels(mb, c, h, w).permute(0,2,3,1);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE) .dataType(DataType.DOUBLE)
@ -112,6 +129,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
.layer(new Yolo2OutputLayer.Builder() .layer(new Yolo2OutputLayer.Builder()
.boundingBoxPriors(bbPrior) .boundingBoxPriors(bbPrior)
.build()) .build())
.setInputType(InputType.convolutional(h, w, depthIn, format))
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -120,7 +138,18 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
String msg = "testYoloOutputLayer() - minibatch = " + mb + ", w=" + w + ", h=" + h + ", l1=" + l1[i] + ", l2=" + l2[i]; String msg = "testYoloOutputLayer() - minibatch = " + mb + ", w=" + w + ", h=" + h + ", l1=" + l1[i] + ", l2=" + l2[i];
System.out.println(msg); System.out.println(msg);
INDArray out = net.output(input);
if(format == CNN2DFormat.NCHW){
assertArrayEquals(new long[]{mb, yoloDepth, h, w}, out.shape());
} else {
assertArrayEquals(new long[]{mb, h, w, yoloDepth}, out.shape());
}
net.fit(input, labels);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
.minAbsoluteError(1e-6)
.labels(labels).subset(true).maxPerParam(100)); .labels(labels).subset(true).maxPerParam(100));
assertTrue(msg, gradOK); assertTrue(msg, gradOK);

View File

@ -21,6 +21,7 @@ import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -80,6 +81,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer {
@JsonDeserialize(using = BoundingBoxesDeserializer.class) @JsonDeserialize(using = BoundingBoxesDeserializer.class)
private INDArray boundingBoxes; private INDArray boundingBoxes;
private CNN2DFormat format = CNN2DFormat.NCHW; //Default for serialization of old formats
private Yolo2OutputLayer() { private Yolo2OutputLayer() {
//No-arg constructor for Jackson JSON //No-arg constructor for Jackson JSON
} }
@ -119,7 +122,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer {
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
//No op InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
this.format = c.getFormat();
} }
@Override @Override

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.objdetect;
import lombok.*; import lombok.*;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
@ -110,6 +111,12 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
Preconditions.checkState(labels.rank() == 4, "Expected rank 4 labels array with shape [minibatch, 4+numClasses, h, w]" + Preconditions.checkState(labels.rank() == 4, "Expected rank 4 labels array with shape [minibatch, 4+numClasses, h, w]" +
" but got rank %s labels array with shape %s", labels.rank(), labels.shape()); " but got rank %s labels array with shape %s", labels.rank(), labels.shape());
boolean nchw = layerConf().getFormat() == CNN2DFormat.NCHW;
INDArray input = nchw ? this.input : this.input.permute(0,3,1,2); //NHWC to NCHW
INDArray labels = this.labels.castTo(input.dataType()); //Ensure correct dtype (same as params); no-op if already correct dtype
if(!nchw)
labels = labels.permute(0,3,1,2); //NHWC to NCHW
double lambdaCoord = layerConf().getLambdaCoord(); double lambdaCoord = layerConf().getLambdaCoord();
double lambdaNoObj = layerConf().getLambdaNoObj(); double lambdaNoObj = layerConf().getLambdaNoObj();
@ -119,7 +126,7 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
int b = (int) layerConf().getBoundingBoxes().size(0); int b = (int) layerConf().getBoundingBoxes().size(0);
int c = (int) labels.size(1)-4; int c = (int) labels.size(1)-4;
INDArray labels = this.labels.castTo(input.dataType()); //Ensure correct dtype (same as params); no-op if already correct dtype
//Various shape arrays, to reuse //Various shape arrays, to reuse
long[] nhw = new long[]{mb, h, w}; long[] nhw = new long[]{mb, h, w};
@ -380,13 +387,17 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
epsWH.addi(dLc_din_wh); epsWH.addi(dLc_din_wh);
epsXY.addi(dLc_din_xy); epsXY.addi(dLc_din_xy);
if(!nchw)
epsOut = epsOut.permute(0,2,3,1); //NCHW to NHWC
return epsOut; return epsOut;
} }
@Override @Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false); assertInputSet(false);
return YoloUtils.activate(layerConf().getBoundingBoxes(), input, workspaceMgr); boolean nchw = layerConf().getFormat() == CNN2DFormat.NCHW;
return YoloUtils.activate(layerConf().getBoundingBoxes(), input, nchw, workspaceMgr);
} }
@Override @Override

View File

@ -39,12 +39,23 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
*/ */
public class YoloUtils { public class YoloUtils {
/** Essentially: just apply activation functions... */ /** Essentially: just apply activation functions... For NCHW format. For NCHW format, use one of the other activate methods */
public static INDArray activate(INDArray boundingBoxPriors, INDArray input) { public static INDArray activate(INDArray boundingBoxPriors, INDArray input) {
return activate(boundingBoxPriors, input, LayerWorkspaceMgr.noWorkspaces()); return activate(boundingBoxPriors, input, true);
} }
public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr){ public static INDArray activate(INDArray boundingBoxPriors, INDArray input, boolean nchw) {
return activate(boundingBoxPriors, input, nchw, LayerWorkspaceMgr.noWorkspaces());
}
public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) {
return activate(boundingBoxPriors, input, true, layerWorkspaceMgr);
}
public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, boolean nchw, LayerWorkspaceMgr layerWorkspaceMgr){
if(!nchw)
input = input.permute(0,3,1,2); //NHWC to NCHW
long mb = input.size(0); long mb = input.size(0);
long h = input.size(2); long h = input.size(2);
long w = input.size(3); long w = input.size(3);
@ -83,6 +94,9 @@ public class YoloUtils {
INDArray outputClasses = output5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [minibatch, C, H, W] INDArray outputClasses = output5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [minibatch, C, H, W]
outputClasses.assign(postSoftmax5d); outputClasses.assign(postSoftmax5d);
if(!nchw)
output = output.permute(0,2,3,1); //NCHW to NHWC
return output; return output;
} }