TF deconv3d import (#8341)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-02 22:57:24 +11:00 committed by GitHub
parent 2844f8b69a
commit 5e312374d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 33 deletions

View File

@ -149,60 +149,53 @@ public class DeConv3D extends DynamicCustomOp {
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
int sD, sH, sW, dD=1, dH=1, dW=1;
val aStrides = nodeDef.getAttrOrThrow("strides");
val tfStrides = aStrides.getList().getIList();
int sD, sH, sW, kD, kH, kW;
List<Long> tfStrides = aStrides.getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
val aPadding = nodeDef.getAttrOrDefault("padding", null);
val paddingMode = aPadding.getS().toStringUtf8();
val args = args();
INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
if (arr == null) {
arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
val varForOp = initWith.getVariable(args[1].name());
if (arr != null)
initWith.associateArrayWithVariable(arr, varForOp);
List<Long> tfDilation = null;
if (attributesForNode.containsKey("dilations")) {
tfDilation = attributesForNode.get("dilations").getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
}
String dataFormat = "nhwc";
val aPadding = nodeDef.getAttrOrDefault("padding", null);
String paddingMode = aPadding.getS().toStringUtf8();
String dataFormat = "NDHWC";
if (nodeDef.containsAttr("data_format")) {
val attr = nodeDef.getAttrOrThrow("data_format");
dataFormat = attr.getS().toStringUtf8().toLowerCase();
}
if (dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW)) {
if(dataFormat.equalsIgnoreCase("NCDHW")){
sD = tfStrides.get(2).intValue();
sH = tfStrides.get(3).intValue();
sW = tfStrides.get(4).intValue();
kD = (int) arr.size(2);
kH = (int) arr.size(3);
kW = (int) arr.size(4);
if(tfDilation != null){
dD = tfDilation.get(2).intValue();
dH = tfDilation.get(3).intValue();
dW = tfDilation.get(4).intValue();
}
} else {
sD = tfStrides.get(1).intValue();
sH = tfStrides.get(2).intValue();
sW = tfStrides.get(3).intValue();
kD = (int) arr.size(0);
kH = (int) arr.size(1);
kW = (int) arr.size(2);
if(tfDilation != null){
dD = tfDilation.get(1).intValue();
dH = tfDilation.get(2).intValue();
dW = tfDilation.get(3).intValue();
}
}
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
DeConv3DConfig conv2DConfig = DeConv3DConfig.builder()
.kD(kD)
.kH(kH)
.kW(kW)
.sD(sD)
.sH(sW)
.sW(sH)
this.config = DeConv3DConfig.builder()
.kD(-1).kH(-1).kW(-1) //Infer from kernel
.sD(sD).sH(sW).sW(sH)
.dD(dD).dH(dH).dW(dW)
.isSameMode(isSameMode)
.dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
.build();
this.config = conv2DConfig;
addArgs();
}
@ -213,6 +206,10 @@ public class DeConv3D extends DynamicCustomOp {
return "deconv3d";
}
@Override
public String tensorflowName() {
return "Conv3DBackpropInputV2";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {

View File

@ -97,7 +97,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"g_11",
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
"multinomial/.*"
"multinomial/.*",
//2019/11/02 AB - need deconv3d changes (for handling shape)
"conv3d_transpose.*"
};
@BeforeClass