diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp
index 15d5d782c..2003eef3f 100644
--- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp
+++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp
@@ -78,7 +78,10 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
     }
 
     const int rank = nonEmptyArrs[0]->rankOf();  // look up to first non-empty array
-    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank);
+    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
+    if(axis < 0){
+        axis += rank;
+    }
 
     // ******** input validation ******** //
     REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
@@ -150,7 +153,10 @@ DECLARE_SHAPE_FN(concat) {
 
     const int rank = arrShapes[0][0];
 
-    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank);
+    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
+    if(axis < 0){
+        axis += rank;
+    }
 
     // ******** input validation ******** //
     REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index 0bc395803..21559d7f4 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -975,8 +975,8 @@ public class DifferentialFunctionFactory {
         return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
     }
 
-    public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
-        return new BiasAddGrad(sameDiff(), input, bias, grad).outputVariables();
+    public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad, boolean nchw) {
+        return new BiasAddGrad(sameDiff(), input, bias, grad, nchw).outputVariables();
     }
 
     public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) {
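Note on the concat.cpp change above: previously the negative-axis adjustment was applied only on the static INT_ARG path, so an axis supplied dynamically via the last input array was never normalized. Both paths now share the same normalization. A standalone sketch of that rule (the class and method names below are illustrative, not part of the patch):

```java
public class AxisNormalizationSketch {

    // Map a possibly-negative axis into [0, rank), as the libnd4j concat op
    // now does for both the static (IArg) and dynamic (last-input) axis paths.
    static int normalizeAxis(int axis, int rank) {
        int out = axis < 0 ? axis + rank : axis;
        if (out < 0 || out >= rank) {
            throw new IllegalArgumentException("Axis must be in range [0, " + (rank - 1) + "], got " + axis);
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(normalizeAxis(-1, 4)); // 3: last dimension of a rank-4 array
        System.out.println(normalizeAxis(1, 4));  // 1: non-negative axes pass through unchanged
    }
}
```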
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java
index d8bf3f695..7d5dbf4fc 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java
@@ -45,12 +45,14 @@ public class BiasAdd extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[] {input, bias}, false);
         bArguments.clear();
         bArguments.add(nchw);
+        this.nchw = nchw;
     }
 
     public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
         super(new INDArray[]{input, bias}, wrapOrNull(output));
         bArguments.clear();
         bArguments.add(nchw);
+        this.nchw = nchw;
     }
 
     @Override
@@ -80,7 +82,7 @@ public class BiasAdd extends DynamicCustomOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> gradient){
-        return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0)));
+        return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0), nchw));
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java
index 0d6ced083..d3007427d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java
@@ -31,9 +31,12 @@ import java.util.Collections;
 import java.util.List;
 
 public class BiasAddGrad extends DynamicCustomOp {
+    protected boolean nchw = true;
 
-    public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient) {
+    public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient, boolean nchw) {
         super(null, sameDiff, new SDVariable[]{input, bias, gradient});
+        this.nchw = nchw;
+        addBArgument(nchw);
     }
 
     public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
@@ -52,8 +55,6 @@ public class BiasAddGrad extends DynamicCustomOp {
         return "biasadd_bp";
     }
 
-
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
         throw new UnsupportedOperationException("Differentiation not supported for op " + getClass().getSimpleName());
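Why the data-format flag has to reach BiasAddGrad as well: the bias gradient is the upstream gradient summed over every axis except the channel axis, and the channel axis position depends on the format (axis 1 for NCHW, axis 3 for NHWC). A minimal Nd4j sketch of that reduction, with illustrative shapes:

```java
import java.util.Arrays;

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BiasGradSketch {
    public static void main(String[] args) {
        // Upstream gradient in NCHW layout: channels at axis 1
        INDArray gradNchw = Nd4j.rand(DataType.DOUBLE, 2, 4, 3, 3);
        INDArray biasGradNchw = gradNchw.sum(0, 2, 3); // reduce all non-channel axes

        // Upstream gradient in NHWC layout: channels at axis 3
        INDArray gradNhwc = Nd4j.rand(DataType.DOUBLE, 2, 3, 3, 4);
        INDArray biasGradNhwc = gradNhwc.sum(0, 1, 2);

        System.out.println(Arrays.toString(biasGradNchw.shape())); // [4]
        System.out.println(Arrays.toString(biasGradNhwc.shape())); // [4]
    }
}
```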
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java
index c860152ca..0e314662a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java
@@ -39,6 +39,7 @@ import java.util.*;
 @Slf4j
 public class Concat extends DynamicCustomOp {
     private int concatDimension = -1;
+    private boolean isDynamicAxis = false;
 
     public Concat(){
 
@@ -83,73 +84,11 @@ public class Concat extends DynamicCustomOp {
     }
 
-    @Override
-    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
-        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
-
-        Map<String, PropertyMapping> concatMap = new HashMap<>();
-        val concatDimProps = PropertyMapping.builder()
-                .tfInputPosition(0)
-                .onnxAttrName("axis")
-                .build();
-        concatMap.put("concatDimension",concatDimProps);
-
-
-        Map<String, PropertyMapping> concatV2Map = new HashMap<>();
-        val concat2DimProps = PropertyMapping.builder()
-                //lalst position
-                .tfInputPosition(-1)
-                .onnxAttrName("axis")
-                .build();
-        concatV2Map.put("concatDimension",concat2DimProps);
-
-        //note that onnx is already covered here
-        ret.put(tensorflowNames()[0],concatMap);
-        ret.put(tensorflowNames()[1],concatV2Map);
-
-
-        return ret;
-    }
-
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        int concatDimension = -1;
-        String input = null;
-        val inputCount = nodeDef.getInputCount();
-        for(int i = 0; i < inputCount; i++) {
-            if(nodeDef.getInput(i).contains("/concat_dim")) {
-                input = nodeDef.getInput(i);
-                break;
-            }
-        }
-
-        //older versions may specify a concat_dim, usually it's the last argument
-        if(input == null) {
-            input = nodeDef.getInput(nodeDef.getInputCount() - 1);
-        }
-
-        val variable = initWith.getVariable(input);
-        // concat dimension is only possible
-        if (variable != null) {
-            val arr = variable.getArr();
-            if (arr.length() == 1) {
-                concatDimension = arr.getInt(0);
-            }
-
-            this.concatDimension = concatDimension;
-            addIArgument(this.concatDimension);
-            log.trace("Concat dimension: {}", concatDimension);
-
-        }
-
-        //don't pass both iArg and last axis down to libnd4j
-        if(inputArguments().length == nodeDef.getInputCount()) {
-            val inputArgs = inputArguments();
-            removeInputArgument(inputArgs[inputArguments().length - 1]);
-        }
-
-        //TODO Fix this: https://github.com/eclipse/deeplearning4j/issues/8285
-        sameDiff.removeArgFromOp(input,this);
+        //TF uses dynamic axis - last argument is a scalar integer array for axis
+        addBArgument(true);
+        isDynamicAxis = true;
     }
 
     @Override
@@ -159,12 +98,6 @@ public class Concat extends DynamicCustomOp {
         return ret;
     }
 
-
-    @Override
-    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        super.initFromOnnx(node, initWith, attributesForNode, graph);
-    }
-
     @Override
     public String onnxName() {
         return "Concat";
@@ -175,7 +108,6 @@ public class Concat extends DynamicCustomOp {
         return "Concat";
     }
 
-
     @Override
     public String[] tensorflowNames() {
         return new String[] {"Concat","ConcatV2"};
@@ -189,18 +121,32 @@ public class Concat extends DynamicCustomOp {
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         SDVariable[] args = args();
-        SDVariable[] bpArgs = Arrays.copyOf(args, args.length + 1);
-        bpArgs[bpArgs.length-1] = i_v.get(0);
-        return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+        SDVariable[] bpArgs;
+        if(isDynamicAxis){
+            bpArgs = Arrays.copyOf(args, args.length + 2);
+            bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3];  //Last input is axis -> move to end of bp args too
+            bpArgs[bpArgs.length - 2] = i_v.get(0);
+            return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+        } else {
+            bpArgs = Arrays.copyOf(args, args.length + 1);
+            bpArgs[bpArgs.length - 1] = i_v.get(0);
+            return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+        }
     }
 
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
        DataType first = dataTypes.get(0);
-        for( int i=1; i
[... a span of the diff was lost in extraction here: the remainder of Concat.calculateOutputDataTypes and the start of the ConcatBp.java diff (file header and the hunk introducing the dynamicAxis flag and dynamic-axis constructor) ...]
-    @Override
-    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        //No op
-    }
-
-
-    @Override
-    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        //No op
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
     @Override
     public Op.Type opType() {
         return Op.Type.CUSTOM;
     }
 
@@ -92,7 +86,7 @@ public class ConcatBp extends DynamicCustomOp {
 
     @Override
     public int getNumOutputs(){
-        return args().length - 1;
+        return args().length - 1 - (dynamicAxis ? 1 : 0);
     }
 
     @Override
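To see concretely what the dynamic-axis branch of Concat.doDiff builds, here is the same array repacking run on plain strings (a standalone illustration, not repo code). With three forward args (two inputs plus the axis variable) the bp op receives five args, and getNumOutputs() above then reports 5 - 2 = 3 outputs, one gradient per forward arg including the axis:

```java
import java.util.Arrays;

public class BpArgsLayoutSketch {
    public static void main(String[] args) {
        // Forward args of a dynamic-axis concat: inputs first, axis last
        String[] fwdArgs = {"x0", "x1", "axis"};

        // Same steps as the doDiff hunk: grow by two slots, copy the axis to
        // the last slot, and place the output gradient just before it
        String[] bpArgs = Arrays.copyOf(fwdArgs, fwdArgs.length + 2);
        bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3];
        bpArgs[bpArgs.length - 2] = "gradOut";

        System.out.println(Arrays.toString(bpArgs)); // [x0, x1, axis, gradOut, axis]
    }
}
```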
" + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - @Override public Op.Type opType() { return Op.Type.CUSTOM; @@ -92,7 +86,7 @@ public class ConcatBp extends DynamicCustomOp { @Override public int getNumOutputs(){ - return args().length - 1; + return args().length - 1 - (dynamicAxis ? 1 : 0); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index b84d7ceea..92bcc71f9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -1358,4 +1358,35 @@ public class LayerOpValidation extends BaseOpValidation { .build()); assertEquals(outCC, outFC); //Fails here } + + @Test + public void testBiasAdd_nchw_nhwc() { + Nd4j.getRandom().setSeed(12345); + + for(boolean nchw : new boolean[]{true, false}) { + log.info("Starting test: {}", nchw ? "nchw" : "nhwc"); + SameDiff sameDiff = SameDiff.create(); + + SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4})); + SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4})); + + SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw); + SDVariable loss = bAdd.std(true); + + + INDArray exp = in.getArr().dup(); + if(nchw){ + exp.addi(b.getArr().reshape(1,4,1,1)); + } else { + exp.addi(b.getArr().reshape(1,1,1,4)); + } + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(bAdd.name(), exp); + + String err = OpValidation.validate(tc); + assertNull(err); + } + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index fb420ac4b..0d25f63d4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -97,7 +97,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "g_11", //2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913 - "multinomial/.*" + "multinomial/.*", + + //2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation + "conv3d_transpose.*" }; @BeforeClass