Op Fixes (#28)
* #8280 biasadd_bp nchw arg fixes (java side) + test
* #8285 Concat op Java side fixes
* Concat op cpp fix - allow dynamic axis to be negative, same as static axis
* Ignores for deconv3d import tests until deconv3d_tf op is implemented

Signed-off-by: AlexDBlack <blacka101@gmail.com>
parent 4763547c9e
commit 948ebef41c
branch master
@@ -78,7 +78,10 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
     }
 
     const int rank = nonEmptyArrs[0]->rankOf();  // look up to first non-empty array
-    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank);
+    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
+    if(axis < 0){
+        axis += rank;
+    }
 
     // ******** input validation ******** //
     REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
@@ -150,7 +153,10 @@ DECLARE_SHAPE_FN(concat) {
 
     const int rank = arrShapes[0][0];
 
-    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank);
+    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
+    if(axis < 0){
+        axis += rank;
+    }
 
     // ******** input validation ******** //
     REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
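Context note (not part of the commit): both hunks apply the same normalization - a negative axis is shifted by the rank (axis += rank) before validation, so axis -1 on a rank-2 input resolves to 1 whether the axis arrives as a static iArg or as the trailing scalar input. Below is a minimal ND4J-level sketch of the two call styles; it assumes the DynamicCustomOp builder API and that Nd4j.exec creates the output automatically, and is illustrative only.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

public class ConcatNegativeAxisSketch {
    public static void main(String[] args) {
        INDArray a = Nd4j.rand(DataType.FLOAT, new long[]{2, 3});
        INDArray b = Nd4j.rand(DataType.FLOAT, new long[]{2, 4});

        // Static axis as an iArg: -1 was already normalized to rank - 1 = 1
        INDArray outStatic = Nd4j.exec(DynamicCustomOp.builder("concat")
                .addInputs(a, b)
                .addIntegerArguments(-1)
                .build())[0];                                      // shape [2, 7]

        // Dynamic axis as the last input: with this fix, -1 is normalized the same way
        INDArray axis = Nd4j.scalar(DataType.INT, -1);
        INDArray outDynamic = Nd4j.exec(DynamicCustomOp.builder("concat")
                .addInputs(a, b, axis)
                .addBooleanArguments(true)                         // axis is in the last input
                .build())[0];                                      // shape [2, 7]

        System.out.println(outStatic.equalShapes(outDynamic));     // true
    }
}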
@@ -975,8 +975,8 @@ public class DifferentialFunctionFactory {
         return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
     }
 
-    public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
-        return new BiasAddGrad(sameDiff(), input, bias, grad).outputVariables();
+    public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad, boolean nchw) {
+        return new BiasAddGrad(sameDiff(), input, bias, grad, nchw).outputVariables();
     }
 
     public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) {
@@ -45,12 +45,14 @@ public class BiasAdd extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[] {input, bias}, false);
         bArguments.clear();
         bArguments.add(nchw);
+        this.nchw = nchw;
     }
 
     public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
         super(new INDArray[]{input, bias}, wrapOrNull(output));
         bArguments.clear();
         bArguments.add(nchw);
+        this.nchw = nchw;
     }
 
     @Override
@@ -80,7 +82,7 @@ public class BiasAdd extends DynamicCustomOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> gradient){
-        return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0)));
+        return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0), nchw));
     }
 
     @Override
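Context note (not part of the commit): forwarding nchw matters because the bias gradient is the output gradient reduced over every dimension except the channel dimension, and that dimension is 1 for NCHW but the last one for NHWC. A small sketch of the expected reduction in plain ND4J, using the same shapes as the test added further down:

import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BiasAddBpReductionSketch {
    public static void main(String[] args) {
        // NCHW: channels on dimension 1 -> bias gradient sums over dims 0, 2, 3
        INDArray gradNchw = Nd4j.rand(DataType.DOUBLE, new long[]{2, 4, 3, 3});
        INDArray biasGradNchw = gradNchw.sum(0, 2, 3);
        System.out.println(Arrays.toString(biasGradNchw.shape()));   // [4]

        // NHWC: channels on the last dimension -> bias gradient sums over dims 0, 1, 2
        INDArray gradNhwc = Nd4j.rand(DataType.DOUBLE, new long[]{2, 3, 3, 4});
        INDArray biasGradNhwc = gradNhwc.sum(0, 1, 2);
        System.out.println(Arrays.toString(biasGradNhwc.shape()));   // [4]
    }
}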
@@ -31,9 +31,12 @@ import java.util.Collections;
 import java.util.List;
 
 public class BiasAddGrad extends DynamicCustomOp {
+    protected boolean nchw = true;
 
-    public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient) {
+    public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient, boolean nchw) {
         super(null, sameDiff, new SDVariable[]{input, bias, gradient});
+        this.nchw = nchw;
+        addBArgument(nchw);
     }
 
     public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
@@ -52,8 +55,6 @@ public class BiasAddGrad extends DynamicCustomOp {
         return "biasadd_bp";
     }
 
-
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
         throw new UnsupportedOperationException("Differentiation not supported for op " + getClass().getSimpleName());
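Context note (not part of the commit): at the libnd4j level, biasadd_bp takes three inputs (input, bias, output gradient) and produces two gradients (for the input and for the bias); the new boolean argument selects which dimension holds the channels. A hedged sketch of invoking the op directly, assuming the DynamicCustomOp builder API and that the executioner creates the two outputs automatically:

import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

public class BiasAddGradDirectSketch {
    public static void main(String[] args) {
        INDArray input = Nd4j.rand(DataType.DOUBLE, new long[]{2, 4, 3, 3});   // NCHW layout
        INDArray bias = Nd4j.rand(DataType.DOUBLE, new long[]{4});
        INDArray gradOut = Nd4j.rand(DataType.DOUBLE, new long[]{2, 4, 3, 3});

        INDArray[] grads = Nd4j.exec(DynamicCustomOp.builder("biasadd_bp")
                .addInputs(input, bias, gradOut)
                .addBooleanArguments(true)          // nchw = true -> channels on dimension 1
                .build());

        System.out.println(Arrays.toString(grads[0].shape()));   // [2, 4, 3, 3] - gradient w.r.t. input
        System.out.println(Arrays.toString(grads[1].shape()));   // [4]          - gradient w.r.t. bias,
                                                                  // i.e. the reduction sketched above
    }
}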
@@ -39,6 +39,7 @@ import java.util.*;
 @Slf4j
 public class Concat extends DynamicCustomOp {
     private int concatDimension = -1;
+    private boolean isDynamicAxis = false;
 
     public Concat(){
@@ -83,73 +84,11 @@ public class Concat extends DynamicCustomOp {
     }
 
-    @Override
-    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
-        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
-
-        Map<String,PropertyMapping> concatMap = new HashMap<>();
-        val concatDimProps = PropertyMapping.builder()
-                .tfInputPosition(0)
-                .onnxAttrName("axis")
-                .build();
-        concatMap.put("concatDimension",concatDimProps);
-
-        Map<String,PropertyMapping> concatV2Map = new HashMap<>();
-        val concat2DimProps = PropertyMapping.builder()
-                //lalst position
-                .tfInputPosition(-1)
-                .onnxAttrName("axis")
-                .build();
-        concatV2Map.put("concatDimension",concat2DimProps);
-
-        //note that onnx is already covered here
-        ret.put(tensorflowNames()[0],concatMap);
-        ret.put(tensorflowNames()[1],concatV2Map);
-
-        return ret;
-    }
-
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        int concatDimension = -1;
-        String input = null;
-        val inputCount = nodeDef.getInputCount();
-        for(int i = 0; i < inputCount; i++) {
-            if(nodeDef.getInput(i).contains("/concat_dim")) {
-                input = nodeDef.getInput(i);
-                break;
-            }
-        }
-
-        //older versions may specify a concat_dim, usually it's the last argument
-        if(input == null) {
-            input = nodeDef.getInput(nodeDef.getInputCount() - 1);
-        }
-
-        val variable = initWith.getVariable(input);
-        // concat dimension is only possible
-        if (variable != null) {
-            val arr = variable.getArr();
-            if (arr.length() == 1) {
-                concatDimension = arr.getInt(0);
-            }
-
-            this.concatDimension = concatDimension;
-            addIArgument(this.concatDimension);
-            log.trace("Concat dimension: {}", concatDimension);
-        }
-
-        //don't pass both iArg and last axis down to libnd4j
-        if(inputArguments().length == nodeDef.getInputCount()) {
-            val inputArgs = inputArguments();
-            removeInputArgument(inputArgs[inputArguments().length - 1]);
-        }
-
-        //TODO Fix this: https://github.com/eclipse/deeplearning4j/issues/8285
-        sameDiff.removeArgFromOp(input,this);
+        //TF uses dynamic axis - last argument is a scalar integer array for axis
+        addBArgument(true);
+        isDynamicAxis = true;
     }
 
     @Override
@@ -159,12 +98,6 @@ public class Concat extends DynamicCustomOp {
         return ret;
     }
 
-    @Override
-    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        super.initFromOnnx(node, initWith, attributesForNode, graph);
-    }
-
     @Override
     public String onnxName() {
         return "Concat";
@@ -175,7 +108,6 @@ public class Concat extends DynamicCustomOp {
         return "Concat";
     }
 
-
     @Override
     public String[] tensorflowNames() {
         return new String[] {"Concat","ConcatV2"};
@@ -189,18 +121,32 @@ public class Concat extends DynamicCustomOp {
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         SDVariable[] args = args();
-        SDVariable[] bpArgs = Arrays.copyOf(args, args.length + 1);
-        bpArgs[bpArgs.length - 1] = i_v.get(0);
-        return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+        SDVariable[] bpArgs;
+        if(isDynamicAxis){
+            bpArgs = Arrays.copyOf(args, args.length + 2);
+            bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3]; //Last input is axis -> move to end of bp args too
+            bpArgs[bpArgs.length - 2] = i_v.get(0);
+            return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+        } else {
+            bpArgs = Arrays.copyOf(args, args.length + 1);
+            bpArgs[bpArgs.length - 1] = i_v.get(0);
+            return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
+        }
     }
 
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         DataType first = dataTypes.get(0);
-        for( int i=1; i<dataTypes.size(); i++ ){
+        for( int i=1; i<dataTypes.size() - (isDynamicAxis ? 1 : 0); i++ ){
             DataType dt = dataTypes.get(i);
             Preconditions.checkState(first == dt, "All inputs must have same datatype - got %s and %s for inputs 0 and %s respectively", first, dt, i);
         }
+        if(isDynamicAxis) {
+            Preconditions.checkState(dataTypes.get(dataTypes.size() - 1).isIntType(),
+                    "For dynamic axis case, last datatype must be an integer type, got input types %s");
+        }
+
         //Output type is same as input types
         return Collections.singletonList(first);
     }
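Context note (not part of the commit): in the dynamic-axis branch the backprop arguments end up laid out as [inputs..., axis, gradient, axis] - the axis keeps its original slot, so every arg of the forward op gets a gradient output, and a copy of it goes to the end, where concat_bp reads it when its boolean argument is set (see the ConcatBp changes below). A small walk-through for two data inputs; the variable names are hypothetical:

import java.util.Arrays;
import org.nd4j.autodiff.samediff.SDVariable;

public class ConcatBpArgsSketch {
    // x0, x1: data inputs of the forward concat; axis: the trailing scalar axis variable
    // kept by the new TF import path; gradOut: gradient of the concat output.
    static SDVariable[] assembleDynamicAxisBpArgs(SDVariable x0, SDVariable x1,
                                                  SDVariable axis, SDVariable gradOut) {
        SDVariable[] args = new SDVariable[]{x0, x1, axis};          // forward op args, axis last
        SDVariable[] bpArgs = Arrays.copyOf(args, args.length + 2);  // [x0, x1, axis, null, null]
        bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3];       // [x0, x1, axis, null, axis]
        bpArgs[bpArgs.length - 2] = gradOut;                         // [x0, x1, axis, gradOut, axis]
        return bpArgs;
        // When ConcatBp is in dynamic-axis mode, getNumOutputs() = 5 - 1 - 1 = 3:
        // one gradient per forward arg (x0, x1 and the axis input).
    }
}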
@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.shape.bp;
 
+import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import onnx.Onnx;
@@ -42,6 +43,7 @@ import java.util.*;
 @Slf4j
 public class ConcatBp extends DynamicCustomOp {
     private int concatDimension;
+    private boolean dynamicAxis;
 
     public ConcatBp(){
@@ -53,38 +55,30 @@ public class ConcatBp extends DynamicCustomOp {
      * @param concatDimension
      * @param inputsAndGrad Original inputs, followed by output gradient
      */
-    public ConcatBp(SameDiff sameDiff, int concatDimension, SDVariable... inputsAndGrad){
+    public ConcatBp(@NonNull SameDiff sameDiff, int concatDimension, @NonNull SDVariable... inputsAndGrad){
         super(null, sameDiff, inputsAndGrad);
         addIArgument(concatDimension);
         this.concatDimension = concatDimension;
     }
 
+    /**
+     *
+     * @param sameDiff       SameDiff instance
+     * @param inputsGradAxis Inputs, gradient array, and axis
+     */
+    public ConcatBp(@NonNull SameDiff sameDiff, @NonNull SDVariable... inputsGradAxis){
+        super(null, sameDiff, inputsGradAxis);
+        Preconditions.checkState(inputsGradAxis[inputsGradAxis.length-1].dataType().isIntType(),
+                "When using this constructor, the last input must be an integer array (for the axis)");
+        addBArgument(true); //Last argument
+        this.dynamicAxis = true;
+    }
+
     @Override
     public String opName() {
         return "concat_bp";
     }
 
-    @Override
-    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        //No op
-    }
-
-    @Override
-    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        //No op
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
     @Override
     public Op.Type opType() {
         return Op.Type.CUSTOM;
@@ -92,7 +86,7 @@ public class ConcatBp extends DynamicCustomOp {
 
     @Override
     public int getNumOutputs(){
-        return args().length - 1;
+        return args().length - 1 - (dynamicAxis ? 1 : 0);
     }
 
     @Override
@@ -1358,4 +1358,35 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build());
         assertEquals(outCC, outFC); //Fails here
     }
+
+    @Test
+    public void testBiasAdd_nchw_nhwc() {
+        Nd4j.getRandom().setSeed(12345);
+
+        for(boolean nchw : new boolean[]{true, false}) {
+            log.info("Starting test: {}", nchw ? "nchw" : "nhwc");
+            SameDiff sameDiff = SameDiff.create();
+
+            SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4}));
+            SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4}));
+
+            SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw);
+            SDVariable loss = bAdd.std(true);
+
+            INDArray exp = in.getArr().dup();
+            if(nchw){
+                exp.addi(b.getArr().reshape(1,4,1,1));
+            } else {
+                exp.addi(b.getArr().reshape(1,1,1,4));
+            }
+
+            TestCase tc = new TestCase(sameDiff)
+                    .gradientCheck(true)
+                    .expectedOutput(bAdd.name(), exp);
+
+            String err = OpValidation.validate(tc);
+            assertNull(err);
+        }
+    }
 }
@@ -97,7 +97,10 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
             "g_11",
 
             //2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
-            "multinomial/.*"
+            "multinomial/.*",
+
+            //2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation
+            "conv3d_transpose.*"
     };
 
     @BeforeClass