Batch Mmul v2 import fix (#41)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-07-03 12:37:14 +10:00 committed by AlexDBlack
parent d854e28b34
commit 8b4b8977a0
2 changed files with 9 additions and 6 deletions

View File

@ -945,7 +945,8 @@ public class TFGraphMapper extends BaseGraphMapper<GraphDef,NodeDef,AttrValue,No
} else if(tensorProto.getOp().equals("Assert")){
return org.nd4j.linalg.api.buffer.DataType.BOOL;
}
log.warn("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", tensorProto.getName(), tensorProto.getOp());
//Not in ops.proto
log.debug("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", tensorProto.getName(), tensorProto.getOp());
//No descriptor... try to fall back on common type attribute names
if(!tensorProto.containsAttr("dtype") && !tensorProto.containsAttr("Tidx") && !tensorProto.containsAttr("T"))

View File

@ -153,7 +153,7 @@ public class Mmul extends DynamicCustomOp {
@Override
public String[] tensorflowNames() {
return new String[]{"MatMul", "BatchMatMul"};
return new String[]{"MatMul", "BatchMatMul", "BatchMatMulV2"};
}
@ -171,7 +171,12 @@ public class Mmul extends DynamicCustomOp {
boolean isTransposeA;
boolean isTransposeB;
if(nodeDef.getOp().equalsIgnoreCase("BatchMatMul")){
if(nodeDef.getOp().equalsIgnoreCase("MatMul")){
isTransposeA = attributesForNode.get("transpose_a").getB();
isTransposeB = attributesForNode.get("transpose_b").getB();
} else {
//BatchMatMul, BatchMatMulV2
//In practice, BatchMatMul seems to use "adj_x" and "adj_y" instead of "transpose_a" and "transpose_b"
if(attributesForNode.containsKey("transpose_a")){
isTransposeA = attributesForNode.get("transpose_a").getB();
@ -183,9 +188,6 @@ public class Mmul extends DynamicCustomOp {
} else {
isTransposeB = attributesForNode.get("adj_y").getB();
}
} else {
isTransposeA = attributesForNode.get("transpose_a").getB();
isTransposeB = attributesForNode.get("transpose_b").getB();
}
MMulTranspose mMulTranspose = MMulTranspose.builder()
.transposeA(isTransposeA).transposeB(isTransposeB)