parent
2844f8b69a
commit
5e312374d0
|
@ -149,60 +149,53 @@ public class DeConv3D extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
int sD, sH, sW, dD=1, dH=1, dW=1;
|
||||||
|
|
||||||
val aStrides = nodeDef.getAttrOrThrow("strides");
|
val aStrides = nodeDef.getAttrOrThrow("strides");
|
||||||
val tfStrides = aStrides.getList().getIList();
|
List<Long> tfStrides = aStrides.getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
|
||||||
int sD, sH, sW, kD, kH, kW;
|
|
||||||
|
|
||||||
val aPadding = nodeDef.getAttrOrDefault("padding", null);
|
List<Long> tfDilation = null;
|
||||||
|
if (attributesForNode.containsKey("dilations")) {
|
||||||
val paddingMode = aPadding.getS().toStringUtf8();
|
tfDilation = attributesForNode.get("dilations").getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
|
||||||
|
|
||||||
val args = args();
|
|
||||||
INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
|
|
||||||
if (arr == null) {
|
|
||||||
arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
|
|
||||||
val varForOp = initWith.getVariable(args[1].name());
|
|
||||||
if (arr != null)
|
|
||||||
initWith.associateArrayWithVariable(arr, varForOp);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
String dataFormat = "nhwc";
|
val aPadding = nodeDef.getAttrOrDefault("padding", null);
|
||||||
|
String paddingMode = aPadding.getS().toStringUtf8();
|
||||||
|
|
||||||
|
String dataFormat = "NDHWC";
|
||||||
if (nodeDef.containsAttr("data_format")) {
|
if (nodeDef.containsAttr("data_format")) {
|
||||||
val attr = nodeDef.getAttrOrThrow("data_format");
|
val attr = nodeDef.getAttrOrThrow("data_format");
|
||||||
dataFormat = attr.getS().toStringUtf8().toLowerCase();
|
dataFormat = attr.getS().toStringUtf8().toLowerCase();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW)) {
|
if(dataFormat.equalsIgnoreCase("NCDHW")){
|
||||||
sD = tfStrides.get(2).intValue();
|
sD = tfStrides.get(2).intValue();
|
||||||
sH = tfStrides.get(3).intValue();
|
sH = tfStrides.get(3).intValue();
|
||||||
sW = tfStrides.get(4).intValue();
|
sW = tfStrides.get(4).intValue();
|
||||||
|
if(tfDilation != null){
|
||||||
kD = (int) arr.size(2);
|
dD = tfDilation.get(2).intValue();
|
||||||
kH = (int) arr.size(3);
|
dH = tfDilation.get(3).intValue();
|
||||||
kW = (int) arr.size(4);
|
dW = tfDilation.get(4).intValue();
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
sD = tfStrides.get(1).intValue();
|
sD = tfStrides.get(1).intValue();
|
||||||
sH = tfStrides.get(2).intValue();
|
sH = tfStrides.get(2).intValue();
|
||||||
sW = tfStrides.get(3).intValue();
|
sW = tfStrides.get(3).intValue();
|
||||||
|
if(tfDilation != null){
|
||||||
kD = (int) arr.size(0);
|
dD = tfDilation.get(1).intValue();
|
||||||
kH = (int) arr.size(1);
|
dH = tfDilation.get(2).intValue();
|
||||||
kW = (int) arr.size(2);
|
dW = tfDilation.get(3).intValue();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
|
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
|
||||||
DeConv3DConfig conv2DConfig = DeConv3DConfig.builder()
|
this.config = DeConv3DConfig.builder()
|
||||||
.kD(kD)
|
.kD(-1).kH(-1).kW(-1) //Infer from kernel
|
||||||
.kH(kH)
|
.sD(sD).sH(sW).sW(sH)
|
||||||
.kW(kW)
|
.dD(dD).dH(dH).dW(dW)
|
||||||
.sD(sD)
|
|
||||||
.sH(sW)
|
|
||||||
.sW(sH)
|
|
||||||
.isSameMode(isSameMode)
|
.isSameMode(isSameMode)
|
||||||
.dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
|
.dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
|
||||||
.build();
|
.build();
|
||||||
this.config = conv2DConfig;
|
|
||||||
|
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
@ -213,6 +206,10 @@ public class DeConv3D extends DynamicCustomOp {
|
||||||
return "deconv3d";
|
return "deconv3d";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "Conv3DBackpropInputV2";
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
|
|
|
@ -97,7 +97,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
"g_11",
|
"g_11",
|
||||||
|
|
||||||
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
|
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
|
||||||
"multinomial/.*"
|
"multinomial/.*",
|
||||||
|
|
||||||
|
//2019/11/02 AB - need deconv3d changes (for handling shape)
|
||||||
|
"conv3d_transpose.*"
|
||||||
};
|
};
|
||||||
|
|
||||||
@BeforeClass
|
@BeforeClass
|
||||||
|
|
Loading…
Reference in New Issue