parent
2844f8b69a
commit
5e312374d0
@@ -149,60 +149,53 @@ public class DeConv3D extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        int sD, sH, sW, kD, kH, kW;
+        int sD, sH, sW, dD=1, dH=1, dW=1;

         val aStrides = nodeDef.getAttrOrThrow("strides");
-        val tfStrides = aStrides.getList().getIList();
-
-        val aPadding = nodeDef.getAttrOrDefault("padding", null);
-
-        val paddingMode = aPadding.getS().toStringUtf8();
-
-        val args = args();
-        INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
-        if (arr == null) {
-            arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
-            val varForOp = initWith.getVariable(args[1].name());
-            if (arr != null)
-                initWith.associateArrayWithVariable(arr, varForOp);
-        }
+        List<Long> tfStrides = aStrides.getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
+
+        List<Long> tfDilation = null;
+        if (attributesForNode.containsKey("dilations")) {
+            tfDilation = attributesForNode.get("dilations").getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
+        }

-        String dataFormat = "nhwc";
+        val aPadding = nodeDef.getAttrOrDefault("padding", null);
+        String paddingMode = aPadding.getS().toStringUtf8();
+
+        String dataFormat = "NDHWC";
         if (nodeDef.containsAttr("data_format")) {
             val attr = nodeDef.getAttrOrThrow("data_format");
             dataFormat = attr.getS().toStringUtf8().toLowerCase();
         }

-        if (dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW)) {
+        if(dataFormat.equalsIgnoreCase("NCDHW")){
             sD = tfStrides.get(2).intValue();
             sH = tfStrides.get(3).intValue();
             sW = tfStrides.get(4).intValue();
-
-            kD = (int) arr.size(2);
-            kH = (int) arr.size(3);
-            kW = (int) arr.size(4);
+            if(tfDilation != null){
+                dD = tfDilation.get(2).intValue();
+                dH = tfDilation.get(3).intValue();
+                dW = tfDilation.get(4).intValue();
+            }
         } else {
             sD = tfStrides.get(1).intValue();
             sH = tfStrides.get(2).intValue();
             sW = tfStrides.get(3).intValue();
-
-            kD = (int) arr.size(0);
-            kH = (int) arr.size(1);
-            kW = (int) arr.size(2);
+            if(tfDilation != null){
+                dD = tfDilation.get(1).intValue();
+                dH = tfDilation.get(2).intValue();
+                dW = tfDilation.get(3).intValue();
+            }
         }

         boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
-        DeConv3DConfig conv2DConfig = DeConv3DConfig.builder()
-                .kD(kD)
-                .kH(kH)
-                .kW(kW)
-                .sD(sD)
-                .sH(sW)
-                .sW(sH)
+        this.config = DeConv3DConfig.builder()
+                .kD(-1).kH(-1).kW(-1) //Infer from kernel
+                .sD(sD).sH(sW).sW(sH)
+                .dD(dD).dH(dH).dW(dW)
                 .isSameMode(isSameMode)
                 .dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
                 .build();
-        this.config = conv2DConfig;

         addArgs();
     }
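A note on the hunk above: the old import path sized the kernel from the weights INDArray (arr.size(...)), which could fail when the weight variable has no array available at import time; the new code passes kD/kH/kW = -1 so the kernel shape is inferred from the weight tensor when the op actually runs. The stride/dilation indexing follows TensorFlow's 5D attribute layout, [mb,c,d,h,w] for NCDHW and [mb,d,h,w,c] for NDHWC, with the mb/c entries always 1. A minimal self-contained sketch of that index mapping (hypothetical helper, not part of the commit):

import java.util.Arrays;
import java.util.List;

public class TfAttrMapping {
    // Spatial (d, h, w) entries sit at indices 2..4 for NCDHW and
    // 1..3 for NDHWC; the batch and channel entries are always 1.
    static int[] spatial(List<Long> attr, boolean ncdhw) {
        int o = ncdhw ? 2 : 1;
        return new int[]{attr.get(o).intValue(), attr.get(o + 1).intValue(), attr.get(o + 2).intValue()};
    }

    public static void main(String[] args) {
        List<Long> strides = Arrays.asList(1L, 2L, 3L, 4L, 1L); // NDHWC: [mb, d, h, w, c]
        System.out.println(Arrays.toString(spatial(strides, false))); // prints [2, 3, 4]
    }
}

One thing worth a second look in review: both the old and new builder chains pass .sH(sW).sW(sH), which appears transposed relative to the extraction order above.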
@@ -213,6 +206,10 @@ public class DeConv3D extends DynamicCustomOp {
         return "deconv3d";
     }

+    @Override
+    public String tensorflowName() {
+        return "Conv3DBackpropInputV2";
+    }
+
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
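The tensorflowName() override added above is what lets the import mapper route TensorFlow Conv3DBackpropInputV2 nodes (the node type tf.nn.conv3d_transpose emits in a frozen graph) to this DeConv3D op. A rough sketch of that style of name-based dispatch (a hypothetical registry for illustration, not SameDiff's actual import code):

import java.util.HashMap;
import java.util.Map;

public class TfOpLookup {
    // Each importable op advertises the TF node type it handles.
    interface ImportableOp { String tensorflowName(); }

    private final Map<String, ImportableOp> byName = new HashMap<>();

    void register(ImportableOp op) {
        byName.put(op.tensorflowName(), op);
    }

    ImportableOp resolve(String tfNodeType) {
        ImportableOp op = byName.get(tfNodeType);
        if (op == null)
            throw new IllegalArgumentException("No mapping for TF op type: " + tfNodeType);
        return op;
    }

    public static void main(String[] args) {
        TfOpLookup lookup = new TfOpLookup();
        lookup.register(() -> "Conv3DBackpropInputV2"); // the mapping DeConv3D now provides
        System.out.println(lookup.resolve("Conv3DBackpropInputV2").tensorflowName());
    }
}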
@@ -97,7 +97,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
             "g_11",

             //2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
-            "multinomial/.*"
+            "multinomial/.*",
+
+            //2019/11/02 AB - need deconv3d changes (for handling shape)
+            "conv3d_transpose.*"
     };

     @BeforeClass
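For context, the entries in this ignore list are regular expressions matched against TF graph test names; any test whose name matches an entry is skipped until the referenced issue is fixed. A small sketch of the assumed matching semantics, using plain String.matches:

import java.util.Arrays;
import java.util.List;

public class ExclusionCheck {
    // Skip a test when its name fully matches any listed regex.
    static final List<String> IGNORE_REGEXES = Arrays.asList("multinomial/.*", "conv3d_transpose.*");

    static boolean skipped(String testName) {
        return IGNORE_REGEXES.stream().anyMatch(testName::matches);
    }

    public static void main(String[] args) {
        System.out.println(skipped("conv3d_transpose_ncdhw")); // true
        System.out.println(skipped("multinomial/rank1"));      // true
        System.out.println(skipped("conv2d_basic"));           // false
    }
}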