diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index 0a153422d..a9b9b8fa8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -149,60 +149,53 @@ public class DeConv3D extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + int sD, sH, sW, dD=1, dH=1, dW=1; + val aStrides = nodeDef.getAttrOrThrow("strides"); - val tfStrides = aStrides.getList().getIList(); - int sD, sH, sW, kD, kH, kW; + List tfStrides = aStrides.getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1 - val aPadding = nodeDef.getAttrOrDefault("padding", null); - - val paddingMode = aPadding.getS().toStringUtf8(); - - val args = args(); - INDArray arr = sameDiff.getVariable(args[1].name()).getArr(); - if (arr == null) { - arr = TFGraphMapper.getNDArrayFromTensor(nodeDef); - val varForOp = initWith.getVariable(args[1].name()); - if (arr != null) - initWith.associateArrayWithVariable(arr, varForOp); + List tfDilation = null; + if (attributesForNode.containsKey("dilations")) { + tfDilation = attributesForNode.get("dilations").getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1 } - String dataFormat = "nhwc"; + val aPadding = nodeDef.getAttrOrDefault("padding", null); + String paddingMode = aPadding.getS().toStringUtf8(); + + String dataFormat = "NDHWC"; if (nodeDef.containsAttr("data_format")) { val attr = nodeDef.getAttrOrThrow("data_format"); dataFormat = attr.getS().toStringUtf8().toLowerCase(); } - if (dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW)) { + if(dataFormat.equalsIgnoreCase("NCDHW")){ sD = tfStrides.get(2).intValue(); sH = tfStrides.get(3).intValue(); sW = tfStrides.get(4).intValue(); - - kD = (int) arr.size(2); - kH = (int) arr.size(3); - kW = (int) arr.size(4); + if(tfDilation != null){ + dD = tfDilation.get(2).intValue(); + dH = tfDilation.get(3).intValue(); + dW = tfDilation.get(4).intValue(); + } } else { sD = tfStrides.get(1).intValue(); sH = tfStrides.get(2).intValue(); sW = tfStrides.get(3).intValue(); - - kD = (int) arr.size(0); - kH = (int) arr.size(1); - kW = (int) arr.size(2); + if(tfDilation != null){ + dD = tfDilation.get(1).intValue(); + dH = tfDilation.get(2).intValue(); + dW = tfDilation.get(3).intValue(); + } } - boolean isSameMode = paddingMode.equalsIgnoreCase("SAME"); - DeConv3DConfig conv2DConfig = DeConv3DConfig.builder() - .kD(kD) - .kH(kH) - .kW(kW) - .sD(sD) - .sH(sW) - .sW(sH) + this.config = DeConv3DConfig.builder() + .kD(-1).kH(-1).kW(-1) //Infer from kernel + .sD(sD).sH(sW).sW(sH) + .dD(dD).dH(dH).dW(dW) .isSameMode(isSameMode) .dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC) .build(); - this.config = conv2DConfig; addArgs(); } @@ -213,6 +206,10 @@ public class DeConv3D extends DynamicCustomOp { return "deconv3d"; } + @Override + public String tensorflowName() { + return "Conv3DBackpropInputV2"; + } @Override public List doDiff(List f1) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index fb420ac4b..4befa1b8c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -97,7 +97,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "g_11", //2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913 - "multinomial/.*" + "multinomial/.*", + + //2019/11/02 AB - need deconv3d changes (for handling shape) + "conv3d_transpose.*" }; @BeforeClass