DL4J/DataVec: Fix Yolo2OutputLayer and ObjectDetectionRecordReader support for NHWC data format (#483)
* Fix Yolo2OutputLayer for NHWC data format Signed-off-by: Alex Black <blacka101@gmail.com> * ObjectDetectionRecordReader NHWC support Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
45ebd4899c
commit
ee3e059b12
|
@ -49,7 +49,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point;
|
||||||
/**
|
/**
|
||||||
* An image record reader for object detection.
|
* An image record reader for object detection.
|
||||||
* <p>
|
* <p>
|
||||||
* Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w]
|
* Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w] (nchw) or [minibatch, h, w, 4+C] (nhwc)
|
||||||
* Where the image is quantized into h x w grid locations.
|
* Where the image is quantized into h x w grid locations.
|
||||||
* <p>
|
* <p>
|
||||||
* Note that this matches the format required for Deeplearning4j's Yolo2OutputLayer
|
* Note that this matches the format required for Deeplearning4j's Yolo2OutputLayer
|
||||||
|
@ -61,26 +61,48 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
|
||||||
private final int gridW;
|
private final int gridW;
|
||||||
private final int gridH;
|
private final int gridH;
|
||||||
private final ImageObjectLabelProvider labelProvider;
|
private final ImageObjectLabelProvider labelProvider;
|
||||||
|
private final boolean nchw;
|
||||||
|
|
||||||
protected Image currentImage;
|
protected Image currentImage;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider)} but hardcoded
|
||||||
|
* to NCHW format
|
||||||
|
*/
|
||||||
|
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) {
|
||||||
|
this(height, width, channels, gridH, gridW, true, labelProvider);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create ObjectDetectionRecordReader with
|
||||||
*
|
*
|
||||||
* @param height Height of the output images
|
* @param height Height of the output images
|
||||||
* @param width Width of the output images
|
* @param width Width of the output images
|
||||||
* @param channels Number of channels for the output images
|
* @param channels Number of channels for the output images
|
||||||
* @param gridH Grid/quantization size (along height dimension) - Y axis
|
* @param gridH Grid/quantization size (along height dimension) - Y axis
|
||||||
* @param gridW Grid/quantization size (along height dimension) - X axis
|
* @param gridW Grid/quantization size (along height dimension) - X axis
|
||||||
|
* @param nchw If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return
|
||||||
|
* NHWC format labels with array shape [minibatch, h, w, 4+C]
|
||||||
* @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
|
* @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
|
||||||
*/
|
*/
|
||||||
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) {
|
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw, ImageObjectLabelProvider labelProvider) {
|
||||||
super(height, width, channels, null, null);
|
super(height, width, channels, null, null);
|
||||||
this.gridW = gridW;
|
this.gridW = gridW;
|
||||||
this.gridH = gridH;
|
this.gridH = gridH;
|
||||||
|
this.nchw = nchw;
|
||||||
this.labelProvider = labelProvider;
|
this.labelProvider = labelProvider;
|
||||||
this.appendLabel = labelProvider != null;
|
this.appendLabel = labelProvider != null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider, ImageTransform)}
|
||||||
|
* but hardcoded to NCHW format
|
||||||
|
*/
|
||||||
|
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW,
|
||||||
|
ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
|
||||||
|
this(height, width, channels, gridH, gridW, true, labelProvider, imageTransform);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* When imageTransform != null, object is removed if new center is outside of transformed image bounds.
|
* When imageTransform != null, object is removed if new center is outside of transformed image bounds.
|
||||||
*
|
*
|
||||||
|
@ -90,13 +112,16 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
|
||||||
* @param gridH Grid/quantization size (along height dimension) - Y axis
|
* @param gridH Grid/quantization size (along height dimension) - Y axis
|
||||||
* @param gridW Grid/quantization size (along height dimension) - X axis
|
* @param gridW Grid/quantization size (along height dimension) - X axis
|
||||||
* @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
|
* @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image
|
||||||
|
* @param nchw If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return
|
||||||
|
* NHWC format labels with array shape [minibatch, h, w, 4+C]
|
||||||
* @param imageTransform ImageTransform - used to transform image and coordinates
|
* @param imageTransform ImageTransform - used to transform image and coordinates
|
||||||
*/
|
*/
|
||||||
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW,
|
public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw,
|
||||||
ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
|
ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
|
||||||
super(height, width, channels, null, null);
|
super(height, width, channels, null, null);
|
||||||
this.gridW = gridW;
|
this.gridW = gridW;
|
||||||
this.gridH = gridH;
|
this.gridH = gridH;
|
||||||
|
this.nchw = nchw;
|
||||||
this.labelProvider = labelProvider;
|
this.labelProvider = labelProvider;
|
||||||
this.appendLabel = labelProvider != null;
|
this.appendLabel = labelProvider != null;
|
||||||
this.imageTransform = imageTransform;
|
this.imageTransform = imageTransform;
|
||||||
|
@ -182,6 +207,10 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
|
||||||
exampleNum++;
|
exampleNum++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(!nchw) {
|
||||||
|
outImg = outImg.permute(0, 2, 3, 1); //NCHW to NHWC
|
||||||
|
outLabel = outLabel.permute(0, 2, 3, 1);
|
||||||
|
}
|
||||||
return new NDArrayRecordBatch(Arrays.asList(outImg, outLabel));
|
return new NDArrayRecordBatch(Arrays.asList(outImg, outLabel));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,6 +285,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
|
||||||
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
||||||
}
|
}
|
||||||
Image image = this.imageLoader.asImageMatrix(dataInputStream);
|
Image image = this.imageLoader.asImageMatrix(dataInputStream);
|
||||||
|
if(!nchw)
|
||||||
|
image.setImage(image.getImage().permute(0,2,3,1));
|
||||||
Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
|
Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
|
||||||
|
|
||||||
List<Writable> ret = RecordConverter.toRecord(image.getImage());
|
List<Writable> ret = RecordConverter.toRecord(image.getImage());
|
||||||
|
@ -264,6 +295,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader {
|
||||||
int nClasses = labels.size();
|
int nClasses = labels.size();
|
||||||
INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW);
|
INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW);
|
||||||
label(image, imageObjectsForPath, outLabel, 0);
|
label(image, imageObjectsForPath, outLabel, 0);
|
||||||
|
if(!nchw)
|
||||||
|
outLabel = outLabel.permute(0,2,3,1); //NCHW to NHWC
|
||||||
ret.add(new NDArrayWritable(outLabel));
|
ret.add(new NDArrayWritable(outLabel));
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
|
|
|
@ -56,6 +56,7 @@ public class TestObjectDetectionRecordReader {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test() throws Exception {
|
public void test() throws Exception {
|
||||||
|
for(boolean nchw : new boolean[]{true, false}) {
|
||||||
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
|
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.newFolder();
|
||||||
|
@ -73,10 +74,10 @@ public class TestObjectDetectionRecordReader {
|
||||||
URI[] u = new FileSplit(new File(path)).locations();
|
URI[] u = new FileSplit(new File(path)).locations();
|
||||||
Arrays.sort(u);
|
Arrays.sort(u);
|
||||||
|
|
||||||
RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, lp);
|
RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, nchw, lp);
|
||||||
rr.initialize(new CollectionInputSplit(u));
|
rr.initialize(new CollectionInputSplit(u));
|
||||||
|
|
||||||
RecordReader imgRR = new ImageRecordReader(h, w, c);
|
RecordReader imgRR = new ImageRecordReader(h, w, c, nchw);
|
||||||
imgRR.initialize(new CollectionInputSplit(u));
|
imgRR.initialize(new CollectionInputSplit(u));
|
||||||
|
|
||||||
List<String> labels = rr.getLabels();
|
List<String> labels = rr.getLabels();
|
||||||
|
@ -135,24 +136,32 @@ public class TestObjectDetectionRecordReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray lArr = ((NDArrayWritable) next.get(1)).get();
|
INDArray lArr = ((NDArrayWritable) next.get(1)).get();
|
||||||
|
if(nchw) {
|
||||||
assertArrayEquals(new long[]{1, 4 + 2, gH, gW}, lArr.shape());
|
assertArrayEquals(new long[]{1, 4 + 2, gH, gW}, lArr.shape());
|
||||||
|
} else {
|
||||||
|
assertArrayEquals(new long[]{1, gH, gW, 4 + 2}, lArr.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!nchw)
|
||||||
|
expLabels = expLabels.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
|
||||||
assertEquals(expLabels, lArr);
|
assertEquals(expLabels, lArr);
|
||||||
}
|
}
|
||||||
|
|
||||||
rr.reset();
|
rr.reset();
|
||||||
Record record = rr.nextRecord();
|
Record record = rr.nextRecord();
|
||||||
RecordMetaDataImageURI metadata = (RecordMetaDataImageURI)record.getMetaData();
|
RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) record.getMetaData();
|
||||||
assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI()));
|
assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI()));
|
||||||
assertEquals(3, metadata.getOrigC());
|
assertEquals(3, metadata.getOrigC());
|
||||||
assertEquals((int)origH[0], metadata.getOrigH());
|
assertEquals((int) origH[0], metadata.getOrigH());
|
||||||
assertEquals((int)origW[0], metadata.getOrigW());
|
assertEquals((int) origW[0], metadata.getOrigW());
|
||||||
|
|
||||||
List<Record> out = new ArrayList<>();
|
List<Record> out = new ArrayList<>();
|
||||||
List<RecordMetaData> meta = new ArrayList<>();
|
List<RecordMetaData> meta = new ArrayList<>();
|
||||||
out.add(record);
|
out.add(record);
|
||||||
meta.add(metadata);
|
meta.add(metadata);
|
||||||
record = rr.nextRecord();
|
record = rr.nextRecord();
|
||||||
metadata = (RecordMetaDataImageURI)record.getMetaData();
|
metadata = (RecordMetaDataImageURI) record.getMetaData();
|
||||||
out.add(record);
|
out.add(record);
|
||||||
meta.add(metadata);
|
meta.add(metadata);
|
||||||
|
|
||||||
|
@ -164,27 +173,27 @@ public class TestObjectDetectionRecordReader {
|
||||||
int[] nonzeroCount = {5, 10};
|
int[] nonzeroCount = {5, 10};
|
||||||
|
|
||||||
ImageTransform transform = new ResizeImageTransform(37, 42);
|
ImageTransform transform = new ResizeImageTransform(37, 42);
|
||||||
RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, lp, transform);
|
RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, nchw, lp, transform);
|
||||||
rrTransform.initialize(new CollectionInputSplit(u));
|
rrTransform.initialize(new CollectionInputSplit(u));
|
||||||
i = 0;
|
i = 0;
|
||||||
while (rrTransform.hasNext()) {
|
while (rrTransform.hasNext()) {
|
||||||
List<Writable> next = rrTransform.next();
|
List<Writable> next = rrTransform.next();
|
||||||
assertEquals(37, transform.getCurrentImage().getWidth());
|
assertEquals(37, transform.getCurrentImage().getWidth());
|
||||||
assertEquals(42, transform.getCurrentImage().getHeight());
|
assertEquals(42, transform.getCurrentImage().getHeight());
|
||||||
INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
|
INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
|
||||||
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
|
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
|
||||||
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
|
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
ImageTransform transform2 = new ResizeImageTransform(1024, 2048);
|
ImageTransform transform2 = new ResizeImageTransform(1024, 2048);
|
||||||
RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, transform2);
|
RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform2);
|
||||||
rrTransform2.initialize(new CollectionInputSplit(u));
|
rrTransform2.initialize(new CollectionInputSplit(u));
|
||||||
i = 0;
|
i = 0;
|
||||||
while (rrTransform2.hasNext()) {
|
while (rrTransform2.hasNext()) {
|
||||||
List<Writable> next = rrTransform2.next();
|
List<Writable> next = rrTransform2.next();
|
||||||
assertEquals(1024, transform2.getCurrentImage().getWidth());
|
assertEquals(1024, transform2.getCurrentImage().getWidth());
|
||||||
assertEquals(2048, transform2.getCurrentImage().getHeight());
|
assertEquals(2048, transform2.getCurrentImage().getHeight());
|
||||||
INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
|
INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
|
||||||
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
|
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
|
||||||
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
|
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
|
||||||
}
|
}
|
||||||
|
@ -194,19 +203,19 @@ public class TestObjectDetectionRecordReader {
|
||||||
new ResizeImageTransform(2048, 4096),
|
new ResizeImageTransform(2048, 4096),
|
||||||
new FlipImageTransform(-1)
|
new FlipImageTransform(-1)
|
||||||
);
|
);
|
||||||
RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, transform3);
|
RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform3);
|
||||||
rrTransform3.initialize(new CollectionInputSplit(u));
|
rrTransform3.initialize(new CollectionInputSplit(u));
|
||||||
i = 0;
|
i = 0;
|
||||||
while (rrTransform3.hasNext()) {
|
while (rrTransform3.hasNext()) {
|
||||||
List<Writable> next = rrTransform3.next();
|
List<Writable> next = rrTransform3.next();
|
||||||
INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
|
INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
|
||||||
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
|
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
|
||||||
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
|
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
//Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception:
|
//Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception:
|
||||||
ImageTransform transform4 = new FlipImageTransform(-1);
|
ImageTransform transform4 = new FlipImageTransform(-1);
|
||||||
RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, lp, transform4);
|
RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, nchw, lp, transform4);
|
||||||
rrTransform4.initialize(new CollectionInputSplit(u));
|
rrTransform4.initialize(new CollectionInputSplit(u));
|
||||||
i = 0;
|
i = 0;
|
||||||
while (rrTransform4.hasNext()) {
|
while (rrTransform4.hasNext()) {
|
||||||
|
@ -215,10 +224,12 @@ public class TestObjectDetectionRecordReader {
|
||||||
assertEquals((int) origW[i], transform4.getCurrentImage().getWidth());
|
assertEquals((int) origW[i], transform4.getCurrentImage().getWidth());
|
||||||
assertEquals((int) origH[i], transform4.getCurrentImage().getHeight());
|
assertEquals((int) origH[i], transform4.getCurrentImage().getHeight());
|
||||||
|
|
||||||
INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
|
INDArray labelArray = ((NDArrayWritable) next.get(1)).get();
|
||||||
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
|
BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
|
||||||
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
|
assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//2 images: 000012.jpg and 000019.jpg
|
//2 images: 000012.jpg and 000019.jpg
|
||||||
|
|
|
@ -24,9 +24,7 @@ import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.*;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
|
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
|
@ -36,6 +34,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -50,17 +50,28 @@ import java.io.File;
|
||||||
import java.io.FileOutputStream;
|
import java.io.FileOutputStream;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
public class YoloGradientCheckTests extends BaseDL4JTest {
|
public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
|
|
||||||
static {
|
static {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private CNN2DFormat format;
|
||||||
|
public YoloGradientCheckTests(CNN2DFormat format){
|
||||||
|
this.format = format;
|
||||||
|
}
|
||||||
|
@Parameterized.Parameters(name = "{0}")
|
||||||
|
public static Object[] params(){
|
||||||
|
return CNN2DFormat.values();
|
||||||
|
}
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@ -97,8 +108,14 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
INDArray input = Nd4j.rand(new int[]{mb, depthIn, h, w});
|
INDArray input, labels;
|
||||||
INDArray labels = yoloLabels(mb, c, h, w);
|
if(format == CNN2DFormat.NCHW){
|
||||||
|
input = Nd4j.rand(DataType.DOUBLE, mb, depthIn, h, w);
|
||||||
|
labels = yoloLabels(mb, c, h, w);
|
||||||
|
} else {
|
||||||
|
input = Nd4j.rand(DataType.DOUBLE, mb, h, w, depthIn);
|
||||||
|
labels = yoloLabels(mb, c, h, w).permute(0,2,3,1);
|
||||||
|
}
|
||||||
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
|
@ -112,6 +129,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
.layer(new Yolo2OutputLayer.Builder()
|
.layer(new Yolo2OutputLayer.Builder()
|
||||||
.boundingBoxPriors(bbPrior)
|
.boundingBoxPriors(bbPrior)
|
||||||
.build())
|
.build())
|
||||||
|
.setInputType(InputType.convolutional(h, w, depthIn, format))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -120,7 +138,18 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
String msg = "testYoloOutputLayer() - minibatch = " + mb + ", w=" + w + ", h=" + h + ", l1=" + l1[i] + ", l2=" + l2[i];
|
String msg = "testYoloOutputLayer() - minibatch = " + mb + ", w=" + w + ", h=" + h + ", l1=" + l1[i] + ", l2=" + l2[i];
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
|
|
||||||
|
INDArray out = net.output(input);
|
||||||
|
if(format == CNN2DFormat.NCHW){
|
||||||
|
assertArrayEquals(new long[]{mb, yoloDepth, h, w}, out.shape());
|
||||||
|
} else {
|
||||||
|
assertArrayEquals(new long[]{mb, h, w, yoloDepth}, out.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
net.fit(input, labels);
|
||||||
|
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
|
||||||
|
.minAbsoluteError(1e-6)
|
||||||
.labels(labels).subset(true).maxPerParam(100));
|
.labels(labels).subset(true).maxPerParam(100));
|
||||||
|
|
||||||
assertTrue(msg, gradOK);
|
assertTrue(msg, gradOK);
|
||||||
|
|
|
@ -21,6 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -80,6 +81,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer {
|
||||||
@JsonDeserialize(using = BoundingBoxesDeserializer.class)
|
@JsonDeserialize(using = BoundingBoxesDeserializer.class)
|
||||||
private INDArray boundingBoxes;
|
private INDArray boundingBoxes;
|
||||||
|
|
||||||
|
private CNN2DFormat format = CNN2DFormat.NCHW; //Default for serialization of old formats
|
||||||
|
|
||||||
private Yolo2OutputLayer() {
|
private Yolo2OutputLayer() {
|
||||||
//No-arg constructor for Jackson JSON
|
//No-arg constructor for Jackson JSON
|
||||||
}
|
}
|
||||||
|
@ -119,7 +122,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
//No op
|
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||||
|
this.format = c.getFormat();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.objdetect;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
@ -110,6 +111,12 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
|
||||||
Preconditions.checkState(labels.rank() == 4, "Expected rank 4 labels array with shape [minibatch, 4+numClasses, h, w]" +
|
Preconditions.checkState(labels.rank() == 4, "Expected rank 4 labels array with shape [minibatch, 4+numClasses, h, w]" +
|
||||||
" but got rank %s labels array with shape %s", labels.rank(), labels.shape());
|
" but got rank %s labels array with shape %s", labels.rank(), labels.shape());
|
||||||
|
|
||||||
|
boolean nchw = layerConf().getFormat() == CNN2DFormat.NCHW;
|
||||||
|
INDArray input = nchw ? this.input : this.input.permute(0,3,1,2); //NHWC to NCHW
|
||||||
|
INDArray labels = this.labels.castTo(input.dataType()); //Ensure correct dtype (same as params); no-op if already correct dtype
|
||||||
|
if(!nchw)
|
||||||
|
labels = labels.permute(0,3,1,2); //NHWC to NCHW
|
||||||
|
|
||||||
double lambdaCoord = layerConf().getLambdaCoord();
|
double lambdaCoord = layerConf().getLambdaCoord();
|
||||||
double lambdaNoObj = layerConf().getLambdaNoObj();
|
double lambdaNoObj = layerConf().getLambdaNoObj();
|
||||||
|
|
||||||
|
@ -119,7 +126,7 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
|
||||||
int b = (int) layerConf().getBoundingBoxes().size(0);
|
int b = (int) layerConf().getBoundingBoxes().size(0);
|
||||||
int c = (int) labels.size(1)-4;
|
int c = (int) labels.size(1)-4;
|
||||||
|
|
||||||
INDArray labels = this.labels.castTo(input.dataType()); //Ensure correct dtype (same as params); no-op if already correct dtype
|
|
||||||
|
|
||||||
//Various shape arrays, to reuse
|
//Various shape arrays, to reuse
|
||||||
long[] nhw = new long[]{mb, h, w};
|
long[] nhw = new long[]{mb, h, w};
|
||||||
|
@ -380,13 +387,17 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
|
||||||
epsWH.addi(dLc_din_wh);
|
epsWH.addi(dLc_din_wh);
|
||||||
epsXY.addi(dLc_din_xy);
|
epsXY.addi(dLc_din_xy);
|
||||||
|
|
||||||
|
if(!nchw)
|
||||||
|
epsOut = epsOut.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
|
||||||
return epsOut;
|
return epsOut;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
|
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
|
||||||
assertInputSet(false);
|
assertInputSet(false);
|
||||||
return YoloUtils.activate(layerConf().getBoundingBoxes(), input, workspaceMgr);
|
boolean nchw = layerConf().getFormat() == CNN2DFormat.NCHW;
|
||||||
|
return YoloUtils.activate(layerConf().getBoundingBoxes(), input, nchw, workspaceMgr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -39,12 +39,23 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
||||||
*/
|
*/
|
||||||
public class YoloUtils {
|
public class YoloUtils {
|
||||||
|
|
||||||
/** Essentially: just apply activation functions... */
|
/** Essentially: just apply activation functions... For NCHW format. For NCHW format, use one of the other activate methods */
|
||||||
public static INDArray activate(INDArray boundingBoxPriors, INDArray input) {
|
public static INDArray activate(INDArray boundingBoxPriors, INDArray input) {
|
||||||
return activate(boundingBoxPriors, input, LayerWorkspaceMgr.noWorkspaces());
|
return activate(boundingBoxPriors, input, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr){
|
public static INDArray activate(INDArray boundingBoxPriors, INDArray input, boolean nchw) {
|
||||||
|
return activate(boundingBoxPriors, input, nchw, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) {
|
||||||
|
return activate(boundingBoxPriors, input, true, layerWorkspaceMgr);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, boolean nchw, LayerWorkspaceMgr layerWorkspaceMgr){
|
||||||
|
if(!nchw)
|
||||||
|
input = input.permute(0,3,1,2); //NHWC to NCHW
|
||||||
|
|
||||||
long mb = input.size(0);
|
long mb = input.size(0);
|
||||||
long h = input.size(2);
|
long h = input.size(2);
|
||||||
long w = input.size(3);
|
long w = input.size(3);
|
||||||
|
@ -83,6 +94,9 @@ public class YoloUtils {
|
||||||
INDArray outputClasses = output5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [minibatch, C, H, W]
|
INDArray outputClasses = output5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [minibatch, C, H, W]
|
||||||
outputClasses.assign(postSoftmax5d);
|
outputClasses.assign(postSoftmax5d);
|
||||||
|
|
||||||
|
if(!nchw)
|
||||||
|
output = output.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue