TF deconv3d import (#8341)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-02 22:57:24 +11:00 committed by GitHub
parent 2844f8b69a
commit 5e312374d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 33 deletions

View File

@ -149,60 +149,53 @@ public class DeConv3D extends DynamicCustomOp {
@Override @Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
int sD, sH, sW, dD=1, dH=1, dW=1;
val aStrides = nodeDef.getAttrOrThrow("strides"); val aStrides = nodeDef.getAttrOrThrow("strides");
val tfStrides = aStrides.getList().getIList(); List<Long> tfStrides = aStrides.getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
int sD, sH, sW, kD, kH, kW;
val aPadding = nodeDef.getAttrOrDefault("padding", null); List<Long> tfDilation = null;
if (attributesForNode.containsKey("dilations")) {
val paddingMode = aPadding.getS().toStringUtf8(); tfDilation = attributesForNode.get("dilations").getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
val args = args();
INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
if (arr == null) {
arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
val varForOp = initWith.getVariable(args[1].name());
if (arr != null)
initWith.associateArrayWithVariable(arr, varForOp);
} }
String dataFormat = "nhwc"; val aPadding = nodeDef.getAttrOrDefault("padding", null);
String paddingMode = aPadding.getS().toStringUtf8();
String dataFormat = "NDHWC";
if (nodeDef.containsAttr("data_format")) { if (nodeDef.containsAttr("data_format")) {
val attr = nodeDef.getAttrOrThrow("data_format"); val attr = nodeDef.getAttrOrThrow("data_format");
dataFormat = attr.getS().toStringUtf8().toLowerCase(); dataFormat = attr.getS().toStringUtf8().toLowerCase();
} }
if (dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW)) { if(dataFormat.equalsIgnoreCase("NCDHW")){
sD = tfStrides.get(2).intValue(); sD = tfStrides.get(2).intValue();
sH = tfStrides.get(3).intValue(); sH = tfStrides.get(3).intValue();
sW = tfStrides.get(4).intValue(); sW = tfStrides.get(4).intValue();
if(tfDilation != null){
kD = (int) arr.size(2); dD = tfDilation.get(2).intValue();
kH = (int) arr.size(3); dH = tfDilation.get(3).intValue();
kW = (int) arr.size(4); dW = tfDilation.get(4).intValue();
}
} else { } else {
sD = tfStrides.get(1).intValue(); sD = tfStrides.get(1).intValue();
sH = tfStrides.get(2).intValue(); sH = tfStrides.get(2).intValue();
sW = tfStrides.get(3).intValue(); sW = tfStrides.get(3).intValue();
if(tfDilation != null){
kD = (int) arr.size(0); dD = tfDilation.get(1).intValue();
kH = (int) arr.size(1); dH = tfDilation.get(2).intValue();
kW = (int) arr.size(2); dW = tfDilation.get(3).intValue();
}
} }
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME"); boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
DeConv3DConfig conv2DConfig = DeConv3DConfig.builder() this.config = DeConv3DConfig.builder()
.kD(kD) .kD(-1).kH(-1).kW(-1) //Infer from kernel
.kH(kH) .sD(sD).sH(sW).sW(sH)
.kW(kW) .dD(dD).dH(dH).dW(dW)
.sD(sD)
.sH(sW)
.sW(sH)
.isSameMode(isSameMode) .isSameMode(isSameMode)
.dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC) .dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
.build(); .build();
this.config = conv2DConfig;
addArgs(); addArgs();
} }
@ -213,6 +206,10 @@ public class DeConv3D extends DynamicCustomOp {
return "deconv3d"; return "deconv3d";
} }
@Override
public String tensorflowName() {
return "Conv3DBackpropInputV2";
}
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {

View File

@ -97,7 +97,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"g_11", "g_11",
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913 //2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
"multinomial/.*" "multinomial/.*",
//2019/11/02 AB - need deconv3d changes (for handling shape)
"conv3d_transpose.*"
}; };
@BeforeClass @BeforeClass