Op Fixes (#28)

* #8280 biasadd_bp nchw arg fixes (java side) + test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8285 Concat op Java side fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Concat op cpp fix - allow dynamic axis to be negative, same as static axis

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ignores for deconv3d import tests until deconv3d_tf op is implemented

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Author: Alex Black, 2019-11-05 00:05:04 +11:00 (committed by GitHub)
Parent: 4763547c9e
Commit: 948ebef41c
8 changed files with 91 additions and 108 deletions


@@ -78,7 +78,10 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
}
const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array
-int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank);
+int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
+if(axis < 0){
+axis += rank;
+}
// ******** input validation ******** //
REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
@@ -150,7 +153,10 @@ DECLARE_SHAPE_FN(concat) {
const int rank = arrShapes[0][0];
-int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank);
+int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
+if(axis < 0){
+axis += rank;
+}
// ******** input validation ******** //
REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
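
Both hunks above make the same change: the dynamic axis, read from the last input array, is now normalized when negative, exactly as the static IArg axis already was. As a minimal sketch of that normalization in isolation (a hypothetical helper for illustration, not code from the repository), assuming NumPy-style negative indexing:

    // Hypothetical helper: a negative axis counts back from the end,
    // so axis -1 on a rank-3 input resolves to 2.
    static int normalizeAxis(int axis, int rank) {
        if (axis < 0) {
            axis += rank;
        }
        if (axis < 0 || axis >= rank) {
            throw new IllegalArgumentException("Concat axis " + axis + " out of range for rank " + rank);
        }
        return axis;
    }

The REQUIRE_TRUE range check above then only ever sees the normalized value.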


@@ -975,8 +975,8 @@ public class DifferentialFunctionFactory {
return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
}
-public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
-return new BiasAddGrad(sameDiff(), input, bias, grad).outputVariables();
+public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad, boolean nchw) {
+return new BiasAddGrad(sameDiff(), input, bias, grad, nchw).outputVariables();
}
public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) {


@@ -45,12 +45,14 @@ public class BiasAdd extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[] {input, bias}, false);
bArguments.clear();
bArguments.add(nchw);
this.nchw = nchw;
}
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
super(new INDArray[]{input, bias}, wrapOrNull(output));
bArguments.clear();
bArguments.add(nchw);
this.nchw = nchw;
}
@Override
@@ -80,7 +82,7 @@ public class BiasAdd extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> gradient){
-return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0)));
+return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0), nchw));
}
@Override
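
The forward op already records the data format as a boolean argument; the change above forwards the same flag to biasadd_bp, because the bias gradient is a reduction over every dimension except the channel dimension, and the position of that dimension depends on the layout. A rough sketch of the equivalent reduction for 4d activations (illustration only, not the libnd4j biasadd_bp kernel):

    // Bias gradient for a 4d activation gradient "epsilon" (an INDArray):
    // NCHW has channels at axis 1, so reduce over {0, 2, 3};
    // NHWC has channels at axis 3, so reduce over {0, 1, 2}.
    static INDArray biasGradient(INDArray epsilon, boolean nchw) {
        return nchw ? epsilon.sum(0, 2, 3) : epsilon.sum(0, 1, 2);
    }

Without the flag, an NHWC network would have its bias gradient reduced over the wrong dimensions, which is what the new gradient-checked test further down exercises for both layouts.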


@@ -31,9 +31,12 @@ import java.util.Collections;
import java.util.List;
public class BiasAddGrad extends DynamicCustomOp {
protected boolean nchw = true;
-public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient) {
+public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient, boolean nchw) {
super(null, sameDiff, new SDVariable[]{input, bias, gradient});
this.nchw = nchw;
addBArgument(nchw);
}
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
@@ -52,8 +55,6 @@ public class BiasAddGrad extends DynamicCustomOp {
return "biasadd_bp";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Differentiation not supported for op " + getClass().getSimpleName());


@@ -39,6 +39,7 @@ import java.util.*;
@Slf4j
public class Concat extends DynamicCustomOp {
private int concatDimension = -1;
private boolean isDynamicAxis = false;
public Concat(){
@@ -83,73 +84,11 @@ public class Concat extends DynamicCustomOp {
}
-@Override
-public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
-Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
-Map<String,PropertyMapping> concatMap = new HashMap<>();
-val concatDimProps = PropertyMapping.builder()
-.tfInputPosition(0)
-.onnxAttrName("axis")
-.build();
-concatMap.put("concatDimension",concatDimProps);
-Map<String,PropertyMapping> concatV2Map = new HashMap<>();
-val concat2DimProps = PropertyMapping.builder()
-//lalst position
-.tfInputPosition(-1)
-.onnxAttrName("axis")
-.build();
-concatV2Map.put("concatDimension",concat2DimProps);
-//note that onnx is already covered here
-ret.put(tensorflowNames()[0],concatMap);
-ret.put(tensorflowNames()[1],concatV2Map);
-return ret;
-}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-int concatDimension = -1;
-String input = null;
-val inputCount = nodeDef.getInputCount();
-for(int i = 0; i < inputCount; i++) {
-if(nodeDef.getInput(i).contains("/concat_dim")) {
-input = nodeDef.getInput(i);
-break;
-}
-}
-//older versions may specify a concat_dim, usually it's the last argument
-if(input == null) {
-input = nodeDef.getInput(nodeDef.getInputCount() - 1);
-}
-val variable = initWith.getVariable(input);
-// concat dimension is only possible
-if (variable != null) {
-val arr = variable.getArr();
-if (arr.length() == 1) {
-concatDimension = arr.getInt(0);
-}
-this.concatDimension = concatDimension;
-addIArgument(this.concatDimension);
-log.trace("Concat dimension: {}", concatDimension);
-}
-//don't pass both iArg and last axis down to libnd4j
-if(inputArguments().length == nodeDef.getInputCount()) {
-val inputArgs = inputArguments();
-removeInputArgument(inputArgs[inputArguments().length - 1]);
-}
-//TODO Fix this: https://github.com/eclipse/deeplearning4j/issues/8285
-sameDiff.removeArgFromOp(input,this);
+//TF uses dynamic axis - last argument is a scalar integer array for axis
+addBArgument(true);
+isDynamicAxis = true;
}
@Override
@@ -159,12 +98,6 @@ public class Concat extends DynamicCustomOp {
return ret;
}
-@Override
-public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-super.initFromOnnx(node, initWith, attributesForNode, graph);
-}
@Override
public String onnxName() {
return "Concat";
@@ -175,7 +108,6 @@ public class Concat extends DynamicCustomOp {
return "Concat";
}
@Override
public String[] tensorflowNames() {
return new String[] {"Concat","ConcatV2"};
@@ -189,18 +121,32 @@ public class Concat extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable[] args = args();
-SDVariable[] bpArgs = Arrays.copyOf(args, args.length + 1);
-bpArgs[bpArgs.length-1] = i_v.get(0);
+SDVariable[] bpArgs;
+if(isDynamicAxis){
+bpArgs = Arrays.copyOf(args, args.length + 2);
+bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3]; //Last input is axis -> move to end of bp args too
+bpArgs[bpArgs.length - 2] = i_v.get(0);
+return Arrays.asList(new ConcatBp(sameDiff, bpArgs).outputVariables());
+} else {
+bpArgs = Arrays.copyOf(args, args.length + 1);
+bpArgs[bpArgs.length - 1] = i_v.get(0);
+return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+}
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
DataType first = dataTypes.get(0);
-for( int i=1; i<dataTypes.size(); i++ ){
+for( int i=1; i<dataTypes.size() - (isDynamicAxis ? 1 : 0); i++ ){
DataType dt = dataTypes.get(i);
Preconditions.checkState(first == dt, "All inputs must have same datatype - got %s and %s for inputs 0 and %s respectively", first, dt, i);
}
if(isDynamicAxis) {
Preconditions.checkState(dataTypes.get(dataTypes.size() - 1).isIntType(),
"For dynamic axis case, last datatype must be an integer type, got input types %s");
}
//Output type is same as input types
return Collections.singletonList(first);
}
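
To make the backprop wiring above concrete: for a dynamic-axis concat with forward args [a, b, axis] and output gradient g, the array copy in doDiff produces the following layout (a trace of the code above, with hypothetical variable names):

    SDVariable[] args = {a, b, axis};                            // forward op args, dynamic axis last
    SDVariable[] bpArgs = Arrays.copyOf(args, args.length + 2);  // [a, b, axis, null, null]
    bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3];       // [a, b, axis, null, axis]
    bpArgs[bpArgs.length - 2] = g;                               // [a, b, axis, g, axis]

The forward op has three args in this case, so its backprop must produce three gradient outputs; with a static axis the args are simply [a, b, g] and two gradients are produced. This lines up with ConcatBp.getNumOutputs() in the next file, which subtracts the gradient input and, when the dynamic-axis flag is set, the trailing axis input. The datatype check above likewise skips the trailing axis input, since it may be an integer type while the data inputs are floating point.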


@@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.bp;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
@@ -42,6 +43,7 @@ import java.util.*;
@Slf4j
public class ConcatBp extends DynamicCustomOp {
private int concatDimension;
private boolean dynamicAxis;
public ConcatBp(){
@@ -53,38 +55,30 @@ public class ConcatBp extends DynamicCustomOp {
* @param concatDimension
* @param inputsAndGrad Original inputs, followed by output gradient
*/
-public ConcatBp(SameDiff sameDiff, int concatDimension, SDVariable... inputsAndGrad){
+public ConcatBp(@NonNull SameDiff sameDiff, int concatDimension, @NonNull SDVariable... inputsAndGrad){
super(null, sameDiff, inputsAndGrad);
addIArgument(concatDimension);
this.concatDimension = concatDimension;
}
/**
*
* @param sameDiff SameDiff instance
* @param inputsGradAxis Inputs, gradient array, and axis
*/
public ConcatBp(@NonNull SameDiff sameDiff, @NonNull SDVariable... inputsGradAxis){
super(null, sameDiff, inputsGradAxis);
Preconditions.checkState(inputsGradAxis[inputsGradAxis.length-1].dataType().isIntType(),
"When using this constructor, the last input must be an integer array (for the axis)");
addBArgument(true); //Last argument
this.dynamicAxis = true;
}
@Override
public String opName() {
return "concat_bp";
}
-@Override
-public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-//No op
-}
-@Override
-public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-//No op
-}
-@Override
-public String onnxName() {
-throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-}
-@Override
-public String tensorflowName() {
-throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-}
@Override
public Op.Type opType() {
return Op.Type.CUSTOM;
@@ -92,7 +86,7 @@ public class ConcatBp extends DynamicCustomOp {
@Override
public int getNumOutputs(){
-return args().length - 1;
+return args().length - 1 - (dynamicAxis ? 1 : 0);
}
@Override


@@ -1358,4 +1358,35 @@ public class LayerOpValidation extends BaseOpValidation {
.build());
assertEquals(outCC, outFC); //Fails here
}
@Test
public void testBiasAdd_nchw_nhwc() {
Nd4j.getRandom().setSeed(12345);
for(boolean nchw : new boolean[]{true, false}) {
log.info("Starting test: {}", nchw ? "nchw" : "nhwc");
SameDiff sameDiff = SameDiff.create();
SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4}));
SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4}));
SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw);
SDVariable loss = bAdd.std(true);
INDArray exp = in.getArr().dup();
if(nchw){
exp.addi(b.getArr().reshape(1,4,1,1));
} else {
exp.addi(b.getArr().reshape(1,1,1,4));
}
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(bAdd.name(), exp);
String err = OpValidation.validate(tc);
assertNull(err);
}
}
}


@@ -97,7 +97,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"g_11",
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
"multinomial/.*"
"multinomial/.*",
//2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation
"conv3d_transpose.*"
};
@BeforeClass