DL4J/DataVec: Fix Yolo2OutputLayer and ObjectDetectionRecordReader support for NHWC data format (#483)
* Fix Yolo2OutputLayer for NHWC data format Signed-off-by: Alex Black <blacka101@gmail.com> * ObjectDetectionRecordReader NHWC support Signed-off-by: Alex Black <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									45ebd4899c
								
							
						
					
					
						commit
						ee3e059b12
					
				| @ -49,7 +49,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point; | ||||
| /** | ||||
|  * An image record reader for object detection. | ||||
|  * <p> | ||||
|  * Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w] | ||||
|  * Format of returned values: 4d array, with dimensions [minibatch, 4+C, h, w] (nchw) or [minibatch, h, w, 4+C] (nhwc) | ||||
|  * Where the image is quantized into h x w grid locations. | ||||
|  * <p> | ||||
|  * Note that this matches the format required for Deeplearning4j's Yolo2OutputLayer | ||||
| @ -61,26 +61,48 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { | ||||
|     private final int gridW; | ||||
|     private final int gridH; | ||||
|     private final ImageObjectLabelProvider labelProvider; | ||||
|     private final boolean nchw; | ||||
| 
 | ||||
|     protected Image currentImage; | ||||
| 
 | ||||
|     /** | ||||
|      * As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider)} but hardcoded | ||||
|      * to NCHW format | ||||
|      */ | ||||
|     public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) { | ||||
|         this(height, width, channels, gridH, gridW, true, labelProvider); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Create ObjectDetectionRecordReader with the given image size, grid size, label format and label provider | ||||
|      * | ||||
|      * @param height        Height of the output images | ||||
|      * @param width         Width of the output images | ||||
|      * @param channels      Number of channels for the output images | ||||
|      * @param gridH         Grid/quantization size (along  height dimension) - Y axis | ||||
|      * @param gridW         Grid/quantization size (along  width dimension) - X axis | ||||
|      * @param nchw          If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return | ||||
|      *                      NHWC format labels with array shape [minibatch, h, w, 4+C] | ||||
|      * @param labelProvider ImageObjectLabelProvider - used to look up which objects are in each image | ||||
|      */ | ||||
|     public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) { | ||||
|     public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw, ImageObjectLabelProvider labelProvider) { | ||||
|         super(height, width, channels, null, null); | ||||
|         this.gridW = gridW; | ||||
|         this.gridH = gridH; | ||||
|         this.nchw = nchw; | ||||
|         this.labelProvider = labelProvider; | ||||
|         this.appendLabel = labelProvider != null; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * As per {@link #ObjectDetectionRecordReader(int, int, int, int, int, boolean, ImageObjectLabelProvider, ImageTransform)} | ||||
|      * but hardcoded to NCHW format | ||||
|      */ | ||||
|     public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, | ||||
|                                        ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) { | ||||
|         this(height, width, channels, gridH, gridW, true, labelProvider, imageTransform); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * When imageTransform != null, object is removed if new center is outside of transformed image bounds. | ||||
|      * | ||||
| @ -90,13 +112,16 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { | ||||
|      * @param gridH          Grid/quantization size (along  height dimension) - Y axis | ||||
|      * @param gridW          Grid/quantization size (along  width dimension) - X axis | ||||
|      * @param labelProvider  ImageObjectLabelProvider - used to look up which objects are in each image | ||||
|      * @param nchw           If true: return NCHW format labels with array shape [minibatch, 4+C, h, w]; if false, return | ||||
|      *                       NHWC format labels with array shape [minibatch, h, w, 4+C] | ||||
|      * @param imageTransform ImageTransform - used to transform image and coordinates | ||||
|      */ | ||||
|     public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, | ||||
|     public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw, | ||||
|                                        ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) { | ||||
|         super(height, width, channels, null, null); | ||||
|         this.gridW = gridW; | ||||
|         this.gridH = gridH; | ||||
|         this.nchw = nchw; | ||||
|         this.labelProvider = labelProvider; | ||||
|         this.appendLabel = labelProvider != null; | ||||
|         this.imageTransform = imageTransform; | ||||
| @ -182,6 +207,10 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { | ||||
|             exampleNum++; | ||||
|         } | ||||
| 
 | ||||
|         if(!nchw) { | ||||
|             outImg = outImg.permute(0, 2, 3, 1);        //NCHW to NHWC | ||||
|             outLabel = outLabel.permute(0, 2, 3, 1); | ||||
|         } | ||||
|         return new NDArrayRecordBatch(Arrays.asList(outImg, outLabel)); | ||||
|     } | ||||
| 
 | ||||
| @ -256,6 +285,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { | ||||
|             imageLoader = new NativeImageLoader(height, width, channels, imageTransform); | ||||
|         } | ||||
|         Image image = this.imageLoader.asImageMatrix(dataInputStream); | ||||
|         if(!nchw) | ||||
|             image.setImage(image.getImage().permute(0,2,3,1)); | ||||
|         Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE); | ||||
| 
 | ||||
|         List<Writable> ret = RecordConverter.toRecord(image.getImage()); | ||||
| @ -264,6 +295,8 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { | ||||
|             int nClasses = labels.size(); | ||||
|             INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW); | ||||
|             label(image, imageObjectsForPath, outLabel, 0); | ||||
|             if(!nchw) | ||||
|                 outLabel = outLabel.permute(0,2,3,1);   //NCHW to NHWC | ||||
|             ret.add(new NDArrayWritable(outLabel)); | ||||
|         } | ||||
|         return ret; | ||||
|  | ||||
| @ -56,6 +56,7 @@ public class TestObjectDetectionRecordReader { | ||||
| 
 | ||||
|     @Test | ||||
|     public void test() throws Exception { | ||||
|         for(boolean nchw : new boolean[]{true, false}) { | ||||
|             ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider(); | ||||
| 
 | ||||
|             File f = testDir.newFolder(); | ||||
| @ -73,10 +74,10 @@ public class TestObjectDetectionRecordReader { | ||||
|             URI[] u = new FileSplit(new File(path)).locations(); | ||||
|             Arrays.sort(u); | ||||
| 
 | ||||
|         RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, lp); | ||||
|             RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, nchw, lp); | ||||
|             rr.initialize(new CollectionInputSplit(u)); | ||||
| 
 | ||||
|         RecordReader imgRR = new ImageRecordReader(h, w, c); | ||||
|             RecordReader imgRR = new ImageRecordReader(h, w, c, nchw); | ||||
|             imgRR.initialize(new CollectionInputSplit(u)); | ||||
| 
 | ||||
|             List<String> labels = rr.getLabels(); | ||||
| @ -135,24 +136,32 @@ public class TestObjectDetectionRecordReader { | ||||
|                 } | ||||
| 
 | ||||
|                 INDArray lArr = ((NDArrayWritable) next.get(1)).get(); | ||||
|                 if(nchw) { | ||||
|                     assertArrayEquals(new long[]{1, 4 + 2, gH, gW}, lArr.shape()); | ||||
|                 } else { | ||||
|                     assertArrayEquals(new long[]{1, gH, gW, 4 + 2}, lArr.shape()); | ||||
|                 } | ||||
| 
 | ||||
|                 if(!nchw) | ||||
|                     expLabels = expLabels.permute(0,2,3,1); //NCHW to NHWC | ||||
| 
 | ||||
|                 assertEquals(expLabels, lArr); | ||||
|             } | ||||
| 
 | ||||
|             rr.reset(); | ||||
|             Record record = rr.nextRecord(); | ||||
|         RecordMetaDataImageURI metadata = (RecordMetaDataImageURI)record.getMetaData(); | ||||
|             RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) record.getMetaData(); | ||||
|             assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI())); | ||||
|             assertEquals(3, metadata.getOrigC()); | ||||
|         assertEquals((int)origH[0], metadata.getOrigH()); | ||||
|         assertEquals((int)origW[0], metadata.getOrigW()); | ||||
|             assertEquals((int) origH[0], metadata.getOrigH()); | ||||
|             assertEquals((int) origW[0], metadata.getOrigW()); | ||||
| 
 | ||||
|             List<Record> out = new ArrayList<>(); | ||||
|             List<RecordMetaData> meta = new ArrayList<>(); | ||||
|             out.add(record); | ||||
|             meta.add(metadata); | ||||
|             record = rr.nextRecord(); | ||||
|         metadata = (RecordMetaDataImageURI)record.getMetaData(); | ||||
|             metadata = (RecordMetaDataImageURI) record.getMetaData(); | ||||
|             out.add(record); | ||||
|             meta.add(metadata); | ||||
| 
 | ||||
| @ -164,27 +173,27 @@ public class TestObjectDetectionRecordReader { | ||||
|             int[] nonzeroCount = {5, 10}; | ||||
| 
 | ||||
|             ImageTransform transform = new ResizeImageTransform(37, 42); | ||||
|         RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, lp, transform); | ||||
|             RecordReader rrTransform = new ObjectDetectionRecordReader(42, 37, c, gH, gW, nchw, lp, transform); | ||||
|             rrTransform.initialize(new CollectionInputSplit(u)); | ||||
|             i = 0; | ||||
|             while (rrTransform.hasNext()) { | ||||
|                 List<Writable> next = rrTransform.next(); | ||||
|                 assertEquals(37, transform.getCurrentImage().getWidth()); | ||||
|                 assertEquals(42, transform.getCurrentImage().getHeight()); | ||||
|             INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); | ||||
|                 INDArray labelArray = ((NDArrayWritable) next.get(1)).get(); | ||||
|                 BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); | ||||
|                 assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); | ||||
|             } | ||||
| 
 | ||||
|             ImageTransform transform2 = new ResizeImageTransform(1024, 2048); | ||||
|         RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, transform2); | ||||
|             RecordReader rrTransform2 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform2); | ||||
|             rrTransform2.initialize(new CollectionInputSplit(u)); | ||||
|             i = 0; | ||||
|             while (rrTransform2.hasNext()) { | ||||
|                 List<Writable> next = rrTransform2.next(); | ||||
|                 assertEquals(1024, transform2.getCurrentImage().getWidth()); | ||||
|                 assertEquals(2048, transform2.getCurrentImage().getHeight()); | ||||
|             INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); | ||||
|                 INDArray labelArray = ((NDArrayWritable) next.get(1)).get(); | ||||
|                 BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); | ||||
|                 assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); | ||||
|             } | ||||
| @ -194,19 +203,19 @@ public class TestObjectDetectionRecordReader { | ||||
|                     new ResizeImageTransform(2048, 4096), | ||||
|                     new FlipImageTransform(-1) | ||||
|             ); | ||||
|         RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, lp, transform3); | ||||
|             RecordReader rrTransform3 = new ObjectDetectionRecordReader(2048, 1024, c, gH, gW, nchw, lp, transform3); | ||||
|             rrTransform3.initialize(new CollectionInputSplit(u)); | ||||
|             i = 0; | ||||
|             while (rrTransform3.hasNext()) { | ||||
|                 List<Writable> next = rrTransform3.next(); | ||||
|             INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); | ||||
|                 INDArray labelArray = ((NDArrayWritable) next.get(1)).get(); | ||||
|                 BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); | ||||
|                 assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); | ||||
|             } | ||||
| 
 | ||||
|             //Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception: | ||||
|             ImageTransform transform4 = new FlipImageTransform(-1); | ||||
|         RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, lp, transform4); | ||||
|             RecordReader rrTransform4 = new ObjectDetectionRecordReader(128, 128, c, gH, gW, nchw, lp, transform4); | ||||
|             rrTransform4.initialize(new CollectionInputSplit(u)); | ||||
|             i = 0; | ||||
|             while (rrTransform4.hasNext()) { | ||||
| @ -215,10 +224,12 @@ public class TestObjectDetectionRecordReader { | ||||
|                 assertEquals((int) origW[i], transform4.getCurrentImage().getWidth()); | ||||
|                 assertEquals((int) origH[i], transform4.getCurrentImage().getHeight()); | ||||
| 
 | ||||
|             INDArray labelArray = ((NDArrayWritable)next.get(1)).get(); | ||||
|                 INDArray labelArray = ((NDArrayWritable) next.get(1)).get(); | ||||
|                 BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0)); | ||||
|                 assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0)); | ||||
|             } | ||||
| 
 | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     //2 images: 000012.jpg and 000019.jpg | ||||
|  | ||||
| @ -24,9 +24,7 @@ import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider; | ||||
| import org.deeplearning4j.BaseDL4JTest; | ||||
| import org.deeplearning4j.TestUtils; | ||||
| import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | ||||
| import org.deeplearning4j.nn.conf.ConvolutionMode; | ||||
| import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | ||||
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||||
| import org.deeplearning4j.nn.conf.*; | ||||
| import org.deeplearning4j.nn.conf.distribution.GaussianDistribution; | ||||
| import org.deeplearning4j.nn.conf.inputs.InputType; | ||||
| import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; | ||||
| @ -36,6 +34,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| import org.junit.runner.RunWith; | ||||
| import org.junit.runners.Parameterized; | ||||
| import org.nd4j.linalg.activations.Activation; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| @ -50,17 +50,28 @@ import java.io.File; | ||||
| import java.io.FileOutputStream; | ||||
| import java.io.InputStream; | ||||
| 
 | ||||
| import static org.junit.Assert.assertArrayEquals; | ||||
| import static org.junit.Assert.assertTrue; | ||||
| 
 | ||||
| /** | ||||
|  * @author Alex Black | ||||
|  */ | ||||
| @RunWith(Parameterized.class) | ||||
| public class YoloGradientCheckTests extends BaseDL4JTest { | ||||
| 
 | ||||
|     static { | ||||
|         Nd4j.setDataType(DataType.DOUBLE); | ||||
|     } | ||||
| 
 | ||||
|     private CNN2DFormat format; | ||||
|     public YoloGradientCheckTests(CNN2DFormat format){ | ||||
|         this.format = format; | ||||
|     } | ||||
|     @Parameterized.Parameters(name = "{0}") | ||||
|     public static Object[] params(){ | ||||
|         return CNN2DFormat.values(); | ||||
|     } | ||||
| 
 | ||||
|     @Rule | ||||
|     public TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
| @ -97,8 +108,14 @@ public class YoloGradientCheckTests extends BaseDL4JTest { | ||||
| 
 | ||||
|             Nd4j.getRandom().setSeed(12345); | ||||
| 
 | ||||
|             INDArray input = Nd4j.rand(new int[]{mb, depthIn, h, w}); | ||||
|             INDArray labels = yoloLabels(mb, c, h, w); | ||||
|             INDArray input, labels; | ||||
|             if(format == CNN2DFormat.NCHW){ | ||||
|                 input = Nd4j.rand(DataType.DOUBLE, mb, depthIn, h, w); | ||||
|                 labels = yoloLabels(mb, c, h, w); | ||||
|             } else { | ||||
|                 input = Nd4j.rand(DataType.DOUBLE, mb, h, w, depthIn); | ||||
|                 labels = yoloLabels(mb, c, h, w).permute(0,2,3,1); | ||||
|             } | ||||
| 
 | ||||
|             MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) | ||||
|                     .dataType(DataType.DOUBLE) | ||||
| @ -112,6 +129,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest { | ||||
|                     .layer(new Yolo2OutputLayer.Builder() | ||||
|                             .boundingBoxPriors(bbPrior) | ||||
|                             .build()) | ||||
|                     .setInputType(InputType.convolutional(h, w, depthIn, format)) | ||||
|                     .build(); | ||||
| 
 | ||||
|             MultiLayerNetwork net = new MultiLayerNetwork(conf); | ||||
| @ -120,7 +138,18 @@ public class YoloGradientCheckTests extends BaseDL4JTest { | ||||
|             String msg = "testYoloOutputLayer() - minibatch = " + mb + ", w=" + w + ", h=" + h + ", l1=" + l1[i] + ", l2=" + l2[i]; | ||||
|             System.out.println(msg); | ||||
| 
 | ||||
|             INDArray out = net.output(input); | ||||
|             if(format == CNN2DFormat.NCHW){ | ||||
|                 assertArrayEquals(new long[]{mb, yoloDepth, h, w}, out.shape()); | ||||
|             } else { | ||||
|                 assertArrayEquals(new long[]{mb, h, w, yoloDepth}, out.shape()); | ||||
|             } | ||||
| 
 | ||||
|             net.fit(input, labels); | ||||
| 
 | ||||
| 
 | ||||
|             boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) | ||||
|                     .minAbsoluteError(1e-6) | ||||
|                     .labels(labels).subset(true).maxPerParam(100)); | ||||
| 
 | ||||
|             assertTrue(msg, gradOK); | ||||
|  | ||||
| @ -21,6 +21,7 @@ import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| import org.deeplearning4j.nn.api.Layer; | ||||
| import org.deeplearning4j.nn.api.ParamInitializer; | ||||
| import org.deeplearning4j.nn.conf.CNN2DFormat; | ||||
| import org.deeplearning4j.nn.conf.GradientNormalization; | ||||
| import org.deeplearning4j.nn.conf.InputPreProcessor; | ||||
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||||
| @ -80,6 +81,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer { | ||||
|     @JsonDeserialize(using = BoundingBoxesDeserializer.class) | ||||
|     private INDArray boundingBoxes; | ||||
| 
 | ||||
|     private CNN2DFormat format = CNN2DFormat.NCHW;  //Default for serialization of old formats | ||||
| 
 | ||||
|     private Yolo2OutputLayer() { | ||||
|         //No-arg constructor for Jackson JSON | ||||
|     } | ||||
| @ -119,7 +122,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer { | ||||
| 
 | ||||
|     @Override | ||||
|     public void setNIn(InputType inputType, boolean override) { | ||||
|         //No op | ||||
|         InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; | ||||
|         this.format = c.getFormat(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|  | ||||
| @ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.objdetect; | ||||
| import lombok.*; | ||||
| import org.deeplearning4j.nn.api.Layer; | ||||
| import org.deeplearning4j.nn.api.layers.IOutputLayer; | ||||
| import org.deeplearning4j.nn.conf.CNN2DFormat; | ||||
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||||
| import org.deeplearning4j.nn.gradient.DefaultGradient; | ||||
| import org.deeplearning4j.nn.gradient.Gradient; | ||||
| @ -110,6 +111,12 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l | ||||
|         Preconditions.checkState(labels.rank() == 4, "Expected rank 4 labels array with shape [minibatch, 4+numClasses, h, w]" + | ||||
|                 " but got rank %s labels array with shape %s", labels.rank(), labels.shape()); | ||||
| 
 | ||||
|         boolean nchw = layerConf().getFormat() == CNN2DFormat.NCHW; | ||||
|         INDArray input = nchw ? this.input : this.input.permute(0,3,1,2);   //NHWC to NCHW | ||||
|         INDArray labels = this.labels.castTo(input.dataType());     //Ensure correct dtype (same as params); no-op if already correct dtype | ||||
|         if(!nchw) | ||||
|             labels = labels.permute(0,3,1,2);   //NHWC to NCHW | ||||
| 
 | ||||
|         double lambdaCoord = layerConf().getLambdaCoord(); | ||||
|         double lambdaNoObj = layerConf().getLambdaNoObj(); | ||||
| 
 | ||||
| @ -119,7 +126,7 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l | ||||
|         int b = (int) layerConf().getBoundingBoxes().size(0); | ||||
|         int c = (int) labels.size(1)-4; | ||||
| 
 | ||||
|         INDArray labels = this.labels.castTo(input.dataType());     //Ensure correct dtype (same as params); no-op if already correct dtype | ||||
| 
 | ||||
| 
 | ||||
|         //Various shape arrays, to reuse | ||||
|         long[] nhw = new long[]{mb, h, w}; | ||||
| @ -380,13 +387,17 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l | ||||
|         epsWH.addi(dLc_din_wh); | ||||
|         epsXY.addi(dLc_din_xy); | ||||
| 
 | ||||
|         if(!nchw) | ||||
|             epsOut = epsOut.permute(0,2,3,1);   //NCHW to NHWC | ||||
| 
 | ||||
|         return epsOut; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { | ||||
|         assertInputSet(false); | ||||
|         return YoloUtils.activate(layerConf().getBoundingBoxes(), input, workspaceMgr); | ||||
|         boolean nchw = layerConf().getFormat() == CNN2DFormat.NCHW; | ||||
|         return YoloUtils.activate(layerConf().getBoundingBoxes(), input, nchw, workspaceMgr); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|  | ||||
| @ -39,12 +39,23 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*; | ||||
|  */ | ||||
| public class YoloUtils { | ||||
| 
 | ||||
|     /** Essentially: just apply activation functions... */ | ||||
|     /** Essentially: just apply activation functions... For NCHW format. For NHWC format, use one of the other activate methods */ | ||||
|     public static INDArray activate(INDArray boundingBoxPriors, INDArray input) { | ||||
|         return activate(boundingBoxPriors, input, LayerWorkspaceMgr.noWorkspaces()); | ||||
|         return activate(boundingBoxPriors, input, true); | ||||
|     } | ||||
| 
 | ||||
|     public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr){ | ||||
|     public static INDArray activate(INDArray boundingBoxPriors, INDArray input, boolean nchw) { | ||||
|         return activate(boundingBoxPriors, input, nchw, LayerWorkspaceMgr.noWorkspaces()); | ||||
|     } | ||||
| 
 | ||||
|     public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) { | ||||
|         return activate(boundingBoxPriors, input, true, layerWorkspaceMgr); | ||||
|     } | ||||
| 
 | ||||
|     public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, boolean nchw, LayerWorkspaceMgr layerWorkspaceMgr){ | ||||
|         if(!nchw) | ||||
|             input = input.permute(0,3,1,2); //NHWC to NCHW | ||||
| 
 | ||||
|         long mb = input.size(0); | ||||
|         long h = input.size(2); | ||||
|         long w = input.size(3); | ||||
| @ -83,6 +94,9 @@ public class YoloUtils { | ||||
|         INDArray outputClasses = output5.get(all(), all(), interval(5, 5+c), all(), all());   //Shape: [minibatch, C, H, W] | ||||
|         outputClasses.assign(postSoftmax5d); | ||||
| 
 | ||||
|         if(!nchw) | ||||
|             output = output.permute(0,2,3,1);       //NCHW to NHWC | ||||
| 
 | ||||
|         return output; | ||||
|     } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user