Op Fixes (#28)
* #8280 biasadd_bp nchw arg fixes (java side) + test
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8285 Concat op Java side fixes
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Concat op cpp fix - allow dynamic axis to be negative, same as static axis
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores for deconv3d import tests until deconv3d_tf op is implemented
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
branch master
parent 4763547c9e
commit 948ebef41c
@@ -78,7 +78,10 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
     }
 
     const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array
-    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank);
+    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
+    if(axis < 0){
+        axis += rank;
+    }
 
     // ******** input validation ******** //
     REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
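The change above reads the dynamic (last-input) axis as-is and normalizes any negative value afterwards, so the dynamic axis now accepts negative values exactly like the static iArg axis. For illustration only, the normalization amounts to the following sketch (hypothetical helper written in Java, not code from this commit):

    // A negative axis counts back from the rank: axis -1 of a rank-3 array is axis 2.
    static int normalizeAxis(int axis, int rank) {
        if (axis < 0) {
            axis += rank;
        }
        if (axis < 0 || axis >= rank) {
            throw new IllegalArgumentException("Axis " + axis + " is out of range for rank " + rank);
        }
        return axis;
    }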
@@ -150,7 +153,10 @@ DECLARE_SHAPE_FN(concat) {
 
     const int rank = arrShapes[0][0];
 
-    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank);
+    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
+    if(axis < 0){
+        axis += rank;
+    }
 
     // ******** input validation ******** //
     REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
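The shape function gets the same normalization before the output shape is assembled. As a reminder of what that shape function computes: the output shape is the first input's shape with the concat axis replaced by the sum of that dimension across all inputs. A rough sketch under that assumption (illustrative helper, not the libnd4j implementation):

    static long[] concatOutputShape(long[][] inputShapes, int axis) {
        long[] out = inputShapes[0].clone();
        int a = axis < 0 ? axis + out.length : axis;   // same negative-axis handling as above
        out[a] = 0;
        for (long[] shape : inputShapes) {
            out[a] += shape[a];                        // concat axis grows, other dims are unchanged
        }
        return out;
    }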
@@ -975,8 +975,8 @@ public class DifferentialFunctionFactory {
         return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
     }
 
-    public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
-        return new BiasAddGrad(sameDiff(), input, bias, grad).outputVariables();
+    public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad, boolean nchw) {
+        return new BiasAddGrad(sameDiff(), input, bias, grad, nchw).outputVariables();
     }
 
     public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) {
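The factory now threads the data-format flag through to the backprop op. The flag matters because the bias gradient is a reduction of the output gradient over every dimension except the channel dimension, and that dimension is 1 for NCHW but the last one for NHWC. Illustrative sketch only; the real reduction is done by the native biasadd_bp op:

    static INDArray biasGradient(INDArray gradAtOutput, boolean nchw) {
        int channelDim = nchw ? 1 : gradAtOutput.rank() - 1;
        int[] reduceDims = new int[gradAtOutput.rank() - 1];
        int j = 0;
        for (int i = 0; i < gradAtOutput.rank(); i++) {
            if (i != channelDim) {
                reduceDims[j++] = i;   // sum over every dimension except the channel one
            }
        }
        return gradAtOutput.sum(reduceDims);
    }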
@@ -45,12 +45,14 @@ public class BiasAdd extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[] {input, bias}, false);
         bArguments.clear();
         bArguments.add(nchw);
+        this.nchw = nchw;
     }
 
     public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
         super(new INDArray[]{input, bias}, wrapOrNull(output));
         bArguments.clear();
         bArguments.add(nchw);
+        this.nchw = nchw;
     }
 
     @Override
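Both constructors now record the flag on the instance as well as in the op's boolean arguments, so doDiff (next hunk) can forward it to biasadd_bp. From the SameDiff API the flag is chosen at the call site, as the new test further down does:

    // nchw = true  -> bias is added along dimension 1 (channels-first, [minibatch, channels, h, w])
    // nchw = false -> bias is added along the last dimension (channels-last, [minibatch, h, w, channels])
    SDVariable out = sameDiff.nn.biasAdd(in, b, nchw);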
@@ -80,7 +82,7 @@ public class BiasAdd extends DynamicCustomOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> gradient){
-        return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0)));
+        return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0), nchw));
     }
 
     @Override
@@ -31,9 +31,12 @@ import java.util.Collections;
 import java.util.List;
 
 public class BiasAddGrad extends DynamicCustomOp {
+    protected boolean nchw = true;
 
-    public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient) {
+    public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient, boolean nchw) {
         super(null, sameDiff, new SDVariable[]{input, bias, gradient});
+        this.nchw = nchw;
+        addBArgument(nchw);
     }
 
     public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
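With the extra constructor argument, callers such as the factory change above can state the data format explicitly; it is kept both as a field and as a boolean op argument so it reaches the native biasadd_bp op. A hedged usage sketch (variable names are illustrative):

    // Assuming SDVariables input, bias and grad already exist in the SameDiff instance sd:
    SDVariable[] gradients = new BiasAddGrad(sd, input, bias, grad, /* nchw = */ true).outputVariables();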
@@ -52,8 +55,6 @@ public class BiasAddGrad extends DynamicCustomOp {
         return "biasadd_bp";
     }
 
-
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
         throw new UnsupportedOperationException("Differentiation not supported for op " + getClass().getSimpleName());
@@ -39,6 +39,7 @@ import java.util.*;
 @Slf4j
 public class Concat extends DynamicCustomOp {
     private int concatDimension = -1;
+    private boolean isDynamicAxis = false;
 
     public Concat(){
 
@@ -83,73 +84,11 @@ public class Concat extends DynamicCustomOp {
 
     }
 
     @Override
-    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
-        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
-
-        Map<String,PropertyMapping> concatMap = new HashMap<>();
-        val concatDimProps = PropertyMapping.builder()
-                .tfInputPosition(0)
-                .onnxAttrName("axis")
-                .build();
-        concatMap.put("concatDimension",concatDimProps);
-
-
-        Map<String,PropertyMapping> concatV2Map = new HashMap<>();
-        val concat2DimProps = PropertyMapping.builder()
-                //lalst position
-                .tfInputPosition(-1)
-                .onnxAttrName("axis")
-                .build();
-        concatV2Map.put("concatDimension",concat2DimProps);
-
-        //note that onnx is already covered here
-        ret.put(tensorflowNames()[0],concatMap);
-        ret.put(tensorflowNames()[1],concatV2Map);
-
-
-        return ret;
-    }
-
-    @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        int concatDimension = -1;
-        String input = null;
-        val inputCount = nodeDef.getInputCount();
-        for(int i = 0; i < inputCount; i++) {
-            if(nodeDef.getInput(i).contains("/concat_dim")) {
-                input = nodeDef.getInput(i);
-                break;
-            }
-        }
-
-        //older versions may specify a concat_dim, usually it's the last argument
-        if(input == null) {
-            input = nodeDef.getInput(nodeDef.getInputCount() - 1);
-        }
-
-        val variable = initWith.getVariable(input);
-        // concat dimension is only possible
-        if (variable != null) {
-            val arr = variable.getArr();
-            if (arr.length() == 1) {
-                concatDimension = arr.getInt(0);
-            }
-
-            this.concatDimension = concatDimension;
-            addIArgument(this.concatDimension);
-            log.trace("Concat dimension: {}", concatDimension);
-
-        }
-
-        //don't pass both iArg and last axis down to libnd4j
-        if(inputArguments().length == nodeDef.getInputCount()) {
-            val inputArgs = inputArguments();
-            removeInputArgument(inputArgs[inputArguments().length - 1]);
-        }
-
-        //TODO Fix this: https://github.com/eclipse/deeplearning4j/issues/8285
-        sameDiff.removeArgFromOp(input,this);
+        //TF uses dynamic axis - last argument is a scalar integer array for axis
+        addBArgument(true);
+        isDynamicAxis = true;
     }
 
     @Override
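The old import tried to resolve the ConcatV2 axis to a constant at import time, passed it as an integer argument, and dropped the axis input (the TODO above pointed at issue #8285). The new import keeps the axis as a genuine last input and only flags the op as dynamic-axis. Conceptually the two configurations differ as in this sketch (the addIArgument/addBArgument calls are the ones used by the class; the surrounding code is illustrative):

    // Static axis: the axis is baked into the op as an integer argument.
    Concat staticAxis = new Concat();
    staticAxis.addIArgument(2);        // runtime inputs: {a, b}

    // Dynamic axis (TF ConcatV2 import): the axis stays a scalar integer input,
    // and a boolean argument tells libnd4j to read it from the last input.
    Concat dynamicAxis = new Concat();
    dynamicAxis.addBArgument(true);    // runtime inputs: {a, b, axisScalar}, which may now also be negative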
@@ -159,12 +98,6 @@ public class Concat extends DynamicCustomOp {
         return ret;
     }
 
-
-    @Override
-    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        super.initFromOnnx(node, initWith, attributesForNode, graph);
-    }
-
     @Override
     public String onnxName() {
         return "Concat";
@@ -175,7 +108,6 @@ public class Concat extends DynamicCustomOp {
         return "Concat";
     }
 
-
     @Override
     public String[] tensorflowNames() {
         return new String[] {"Concat","ConcatV2"};
@@ -189,18 +121,32 @@ public class Concat extends DynamicCustomOp {
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         SDVariable[] args = args();
-        SDVariable[] bpArgs = Arrays.copyOf(args, args.length + 1);
-        bpArgs[bpArgs.length-1] = i_v.get(0);
-        return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+        SDVariable[] bpArgs;
+        if(isDynamicAxis){
+            bpArgs = Arrays.copyOf(args, args.length + 2);
+            bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3]; //Last input is axis -> move to end of bp args too
+            bpArgs[bpArgs.length - 2] = i_v.get(0);
+            return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+        } else {
+            bpArgs = Arrays.copyOf(args, args.length + 1);
+            bpArgs[bpArgs.length - 1] = i_v.get(0);
+            return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+        }
     }
 
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         DataType first = dataTypes.get(0);
-        for( int i=1; i<dataTypes.size(); i++ ){
+
+        for( int i=1; i<dataTypes.size() - (isDynamicAxis ? 1 : 0); i++ ){
             DataType dt = dataTypes.get(i);
             Preconditions.checkState(first == dt, "All inputs must have same datatype - got %s and %s for inputs 0 and %s respectively", first, dt, i);
         }
+        if(isDynamicAxis) {
+            Preconditions.checkState(dataTypes.get(dataTypes.size() - 1).isIntType(),
+                    "For dynamic axis case, last datatype must be an integer type, got input types %s");
+        }
+
         //Output type is same as input types
         return Collections.singletonList(first);
     }
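To make the argument shuffle in doDiff concrete: for a dynamic-axis concat whose forward inputs are {a, b, axis} and whose output gradient is g, the backprop op receives the original inputs, then the gradient, with the axis variable also appended at the end (worked illustration with made-up variable names, tracing the lines above):

    SDVariable[] args = {a, b, axis};                            // forward op inputs, axis last
    SDVariable[] bpArgs = Arrays.copyOf(args, args.length + 2);  // {a, b, axis, null, null}
    bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3];       // {a, b, axis, null, axis}
    bpArgs[bpArgs.length - 2] = g;                               // {a, b, axis, g, axis}

For the dynamic-axis case the last input is the axis, so in calculateOutputDataTypes it is excluded from the same-datatype check and instead required to be an integer type.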
@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.shape.bp;
 
+import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import onnx.Onnx;
@@ -42,6 +43,7 @@ import java.util.*;
 @Slf4j
 public class ConcatBp extends DynamicCustomOp {
     private int concatDimension;
+    private boolean dynamicAxis;
 
     public ConcatBp(){
 
@@ -53,38 +55,30 @@ public class ConcatBp extends DynamicCustomOp {
      * @param concatDimension
      * @param inputsAndGrad Original inputs, followed by output gradient
      */
-    public ConcatBp(SameDiff sameDiff, int concatDimension, SDVariable... inputsAndGrad){
+    public ConcatBp(@NonNull SameDiff sameDiff, int concatDimension, @NonNull SDVariable... inputsAndGrad){
         super(null, sameDiff, inputsAndGrad);
         addIArgument(concatDimension);
         this.concatDimension = concatDimension;
     }
 
+    /**
+     *
+     * @param sameDiff SameDiff instance
+     * @param inputsGradAxis Inputs, gradient array, and axis
+     */
+    public ConcatBp(@NonNull SameDiff sameDiff, @NonNull SDVariable... inputsGradAxis){
+        super(null, sameDiff, inputsGradAxis);
+        Preconditions.checkState(inputsGradAxis[inputsGradAxis.length-1].dataType().isIntType(),
+                "When using this constructor, the last input must be an integer array (for the axis)");
+        addBArgument(true); //Last argument
+        this.dynamicAxis = true;
+    }
+
     @Override
     public String opName() {
         return "concat_bp";
     }
 
-    @Override
-    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        //No op
-    }
-
-
-    @Override
-    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        //No op
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
     @Override
     public Op.Type opType() {
         return Op.Type.CUSTOM;
@@ -92,7 +86,7 @@ public class ConcatBp extends DynamicCustomOp {
 
     @Override
     public int getNumOutputs(){
-        return args().length - 1;
+        return args().length - 1 - (dynamicAxis ? 1 : 0);
     }
 
     @Override
@@ -1358,4 +1358,35 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build());
         assertEquals(outCC, outFC); //Fails here
     }
+
+    @Test
+    public void testBiasAdd_nchw_nhwc() {
+        Nd4j.getRandom().setSeed(12345);
+
+        for(boolean nchw : new boolean[]{true, false}) {
+            log.info("Starting test: {}", nchw ? "nchw" : "nhwc");
+            SameDiff sameDiff = SameDiff.create();
+
+            SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4}));
+            SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4}));
+
+            SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw);
+            SDVariable loss = bAdd.std(true);
+
+
+            INDArray exp = in.getArr().dup();
+            if(nchw){
+                exp.addi(b.getArr().reshape(1,4,1,1));
+            } else {
+                exp.addi(b.getArr().reshape(1,1,1,4));
+            }
+
+            TestCase tc = new TestCase(sameDiff)
+                    .gradientCheck(true)
+                    .expectedOutput(bAdd.name(), exp);
+
+            String err = OpValidation.validate(tc);
+            assertNull(err);
+        }
+    }
 }
@@ -97,7 +97,10 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
             "g_11",
 
             //2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
-            "multinomial/.*"
+            "multinomial/.*",
+
+            //2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation
+            "conv3d_transpose.*"
     };
 
     @BeforeClass