* #8901 Avoid unnecessary warning in SameDiffLoss

Signed-off-by: Alex Black <blacka101@gmail.com>

* Improved error messages for conv2d layers - NCHW vs. NHWC

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-05 12:24:03 +10:00 committed by GitHub
parent 7651a486e1
commit 615a48f0cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 121 additions and 20 deletions

View File

@ -18,11 +18,9 @@ package org.deeplearning4j.nn.layers.convolution;
import lombok.*; import lombok.*;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer; import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
@ -35,6 +33,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;
@ -49,6 +48,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@RunWith(Parameterized.class) @RunWith(Parameterized.class)
public class ConvDataFormatTests extends BaseDL4JTest { public class ConvDataFormatTests extends BaseDL4JTest {
@ -971,4 +971,58 @@ public class ConvDataFormatTests extends BaseDL4JTest {
return null; return null;
} }
} }
/**
 * Checks that each 2D convolution-style layer rejects input laid out in the other data format
 * (NCHW data fed to an NHWC layer and vice versa), and that the resulting exception message
 * contains {@link ConvolutionUtils#NCHW_NHWC_ERROR_MSG}.
 *
 * For every {@link CNN2DFormat} and each of ConvolutionLayer, DepthwiseConvolution2D,
 * Deconvolution2D and SeparableConvolution2D, a single-layer network is built with nIn = 3.
 * Correctly formatted input must pass through {@code output}; wrongly formatted input must throw
 * {@link DL4JInvalidInputException}.
 */
@Test
public void testWrongFormatIn(){
    for(CNN2DFormat df : CNN2DFormat.values()){
        for(int i=0; i<4; i++ ){
            NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder()
                    .list();
            switch (i){
                case 0:
                    b.layer(new ConvolutionLayer.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
                    break;
                case 1:
                    b.layer(new DepthwiseConvolution2D.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
                    break;
                case 2:
                    b.layer(new Deconvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
                    break;
                case 3:
                    b.layer(new SeparableConvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
                    break;
                default:
                    // Guards against the loop bound being raised without adding a matching layer case
                    throw new IllegalStateException("No layer defined for test case " + i);
            }

            MultiLayerNetwork net = new MultiLayerNetwork(b.build());
            net.init();

            // Channels = 3 sits at dimension 1 for NCHW and dimension 3 for NHWC; height = width = 12
            INDArray in;
            INDArray wrongFormatIn;
            if(df == CNN2DFormat.NCHW){
                in = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12);
                wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3);
            } else {
                in = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3);
                wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12);
            }

            // Sanity check: correctly formatted input must not throw
            net.output(in);

            try {
                net.output(wrongFormatIn);
                // Without this, the test would pass silently if the layer accepted the bad input.
                // Fully qualified because only assertEquals/assertTrue are statically imported here;
                // the AssertionError it throws is not caught by the handler below.
                org.junit.Assert.fail("Expected DL4JInvalidInputException for wrong-format input: data format = "
                        + df + ", layer case = " + i);
            } catch (DL4JInvalidInputException e){
                String msg = e.getMessage();
                assertTrue(msg, msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG));
            }
        }
    }
}
} }

View File

@ -63,6 +63,9 @@ public class Deconvolution2D extends ConvolutionLayer {
// Builds the layer configuration from the given builder: superclass construction first, then constraint initialization.
protected Deconvolution2D(BaseConvBuilder<?> builder) { protected Deconvolution2D(BaseConvBuilder<?> builder) {
super(builder); super(builder);
initializeConstraints(builder); initializeConstraints(builder);
// The data format is copied from the builder's 'format' field only when the builder is a Builder
// (presumably Deconvolution2D.Builder - confirm); for any other BaseConvBuilder this assignment is skipped.
if(builder instanceof Builder){
this.cnn2dDataFormat = ((Builder) builder).format;
}
} }
public boolean hasBias() { public boolean hasBias() {
@ -136,7 +139,7 @@ public class Deconvolution2D extends ConvolutionLayer {
private CNN2DFormat format = CNN2DFormat.NCHW; private CNN2DFormat format = CNN2DFormat.NCHW;
/**
 * Stores the given CNN 2D data format in this builder's {@code format} field.
 * In this diff the setter is renamed from {@code format(...)} (old side) to {@code dataFormat(...)} (new side).
 *
 * @param format data format value to store on the builder
 * @return this builder, so that calls can be chained
 */
public Builder format(CNN2DFormat format){ public Builder dataFormat(CNN2DFormat format){
this.format = format; this.format = format;
return this; return this;
} }

View File

@ -310,11 +310,21 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
String layerName = conf.getLayer().getLayerName(); String layerName = conf.getLayer().getLayerName();
if (layerName == null) if (layerName == null)
layerName = "(not named)"; layerName = "(not named)";
throw new DL4JInvalidInputException("Cannot do forward pass in Convolution layer (layer name = " + layerName
String s = "Cannot do forward pass in Convolution layer (layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration" + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data input channels = " + input.size(dim) + ", " + layerConf().getCnn2dDataFormat().dimensionNames() + " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + layerConf().getCnn2dDataFormat().dimensionNames()
+ "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") " + "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId()); + layerId();
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
if(input.size(dimIfWrongFormat) == inDepth){
//User might have passed NCHW data to a NHWC net, or vice versa?
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
}
throw new DL4JInvalidInputException(s);
} }
} }

View File

@ -190,12 +190,21 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
String layerName = conf.getLayer().getLayerName(); String layerName = conf.getLayer().getLayerName();
if (layerName == null) if (layerName == null)
layerName = "(not named)"; layerName = "(not named)";
throw new DL4JInvalidInputException("Cannot do forward pass in Deconvolution2D layer (layer name = " + layerName
String s = "Cannot do forward pass in Deconvolution2D layer (layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration" + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data input channels = " + input.size(cDim) + ", " + " (data format = " + format + ", data input channels = " + input.size(cDim) + ", "
+ (nchw ? "[minibatch,inputDepth,height,width]" : "[minibatch,height,width,inputDepth]") + "=" + (nchw ? "[minibatch,inputDepth,height,width]" : "[minibatch,height,width,inputDepth]") + "="
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") " + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId()); + layerId();
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
if(input.size(dimIfWrongFormat) == inDepth){
//User might have passed NCHW data to a NHWC net, or vice versa?
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
}
throw new DL4JInvalidInputException(s);
} }
int kH = (int) weights.size(2); int kH = (int) weights.size(2);
int kW = (int) weights.size(3); int kW = (int) weights.size(3);

View File

@ -183,13 +183,21 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
String layerName = conf.getLayer().getLayerName(); String layerName = conf.getLayer().getLayerName();
if (layerName == null) if (layerName == null)
layerName = "(not named)"; layerName = "(not named)";
throw new DL4JInvalidInputException("Cannot do forward pass in DepthwiseConvolution2D layer " +
String s = "Cannot do forward pass in DepthwiseConvolution2D layer " +
"(layer name = " + layerName "(layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration" + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data input channels = " + input.size(1) + ", " + " (data format = " + format + ", data input channels = " + input.size(1) + ", "
+ (nchw ? "[minibatch,inputDepth,height,width]=" : "[minibatch,height,width,inputDepth]=") + (nchw ? "[minibatch,inputDepth,height,width]=" : "[minibatch,height,width,inputDepth]=")
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") " + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId()); + layerId();
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
if(input.size(dimIfWrongFormat) == inDepth){
//User might have passed NCHW data to a NHWC net, or vice versa?
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
}
throw new DL4JInvalidInputException(s);
} }
int kH = (int) depthWiseWeights.size(0); int kH = (int) depthWiseWeights.size(0);
int kW = (int) depthWiseWeights.size(1); int kW = (int) depthWiseWeights.size(1);

View File

@ -211,11 +211,20 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
String layerName = conf.getLayer().getLayerName(); String layerName = conf.getLayer().getLayerName();
if (layerName == null) if (layerName == null)
layerName = "(not named)"; layerName = "(not named)";
throw new DL4JInvalidInputException("Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration" + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]=" + " (data format = " + format + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") " + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId()); + layerId();
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
if(input.size(dimIfWrongFormat) == inDepth){
//User might have passed NCHW data to a NHWC net, or vice versa?
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
}
throw new DL4JInvalidInputException(s);
} }
int kH = (int) depthWiseWeights.size(2); int kH = (int) depthWiseWeights.size(2);
int kW = (int) depthWiseWeights.size(3); int kW = (int) depthWiseWeights.size(3);

View File

@ -48,6 +48,13 @@ import java.util.Arrays;
*/ */
public class ConvolutionUtils { public class ConvolutionUtils {
/**
 * Hint text for error messages where the failure may be caused by image data/activations being in
 * NCHW (channels first) layout when NHWC (channels last) was configured, or vice versa.
 * Explains where the data format can be configured on layers, on the network input type, and on
 * the image loading utilities.
 */
public static final String NCHW_NHWC_ERROR_MSG = "Note: Convolution layers can be configured for either NCHW (channels first)" +
" or NHWC (channels last) format for input images and activations.\n" +
"Layers can be configured using .dataFormat(CNN2DFormat.NCHW/NHWC) when constructing the layer, or for the entire net using" +
" .setInputType(InputType.convolutional(height, width, depth, CNN2DFormat.NCHW/NHWC)).\n" +
"ImageRecordReader and NativeImageLoader can also be configured to load image data in either NCHW or NHWC format which must match the network";
private static final int[] ONES = new int[]{1, 1}; private static final int[] ONES = new int[]{1, 1};

View File

@ -38,7 +38,7 @@ import java.util.Map;
*/ */
public abstract class SameDiffLoss implements ILossFunction { public abstract class SameDiffLoss implements ILossFunction {
protected transient SameDiff sd; protected transient SameDiff sd;
protected transient SDVariable scoreVariable; protected transient SDVariable scorePerExampleVariable;
protected SameDiffLoss() { protected SameDiffLoss() {
@ -60,7 +60,8 @@ public abstract class SameDiffLoss implements ILossFunction {
sd = SameDiff.create(); sd = SameDiff.create();
SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1); SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1);
SDVariable labels = sd.placeHolder("labels", dataType, -1); SDVariable labels = sd.placeHolder("labels", dataType, -1);
scoreVariable = this.defineLoss(sd, layerInput, labels); scorePerExampleVariable = this.defineLoss(sd, layerInput, labels);
scorePerExampleVariable.markAsLoss();
sd.createGradFunction("layerInput"); sd.createGradFunction("layerInput");
} }
@ -112,7 +113,7 @@ public abstract class SameDiffLoss implements ILossFunction {
m.put("labels", labels); m.put("labels", labels);
m.put("layerInput", output); m.put("layerInput", output);
INDArray scoreArr = sd.outputSingle(m,scoreVariable.name()); INDArray scoreArr = sd.outputSingle(m, scorePerExampleVariable.name());
if (mask != null) { if (mask != null) {
LossUtil.applyMask(scoreArr, mask); LossUtil.applyMask(scoreArr, mask);