parent
d854e28b34
commit
8b4b8977a0
|
@ -945,7 +945,8 @@ public class TFGraphMapper extends BaseGraphMapper<GraphDef,NodeDef,AttrValue,No
|
|||
} else if(tensorProto.getOp().equals("Assert")){
|
||||
return org.nd4j.linalg.api.buffer.DataType.BOOL;
|
||||
}
|
||||
log.warn("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", tensorProto.getName(), tensorProto.getOp());
|
||||
//Not in ops.proto
|
||||
log.debug("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", tensorProto.getName(), tensorProto.getOp());
|
||||
|
||||
//No descriptor... try to fall back on common type attribute names
|
||||
if(!tensorProto.containsAttr("dtype") && !tensorProto.containsAttr("Tidx") && !tensorProto.containsAttr("T"))
|
||||
|
|
|
@ -153,7 +153,7 @@ public class Mmul extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"MatMul", "BatchMatMul"};
|
||||
return new String[]{"MatMul", "BatchMatMul", "BatchMatMulV2"};
|
||||
}
|
||||
|
||||
|
||||
|
@ -171,7 +171,12 @@ public class Mmul extends DynamicCustomOp {
|
|||
|
||||
boolean isTransposeA;
|
||||
boolean isTransposeB;
|
||||
if(nodeDef.getOp().equalsIgnoreCase("BatchMatMul")){
|
||||
if(nodeDef.getOp().equalsIgnoreCase("MatMul")){
|
||||
isTransposeA = attributesForNode.get("transpose_a").getB();
|
||||
isTransposeB = attributesForNode.get("transpose_b").getB();
|
||||
|
||||
} else {
|
||||
//BatchMatMul, BatchMatMulV2
|
||||
//In practice, BatchMatMul seems to use "adj_x" and "adj_y" instead of "transpose_a" and "transpose_b"
|
||||
if(attributesForNode.containsKey("transpose_a")){
|
||||
isTransposeA = attributesForNode.get("transpose_a").getB();
|
||||
|
@ -183,9 +188,6 @@ public class Mmul extends DynamicCustomOp {
|
|||
} else {
|
||||
isTransposeB = attributesForNode.get("adj_y").getB();
|
||||
}
|
||||
} else {
|
||||
isTransposeA = attributesForNode.get("transpose_a").getB();
|
||||
isTransposeB = attributesForNode.get("transpose_b").getB();
|
||||
}
|
||||
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
||||
.transposeA(isTransposeA).transposeB(isTransposeB)
|
||||
|
|
Loading…
Reference in New Issue