parent
d854e28b34
commit
8b4b8977a0
|
@ -945,7 +945,8 @@ public class TFGraphMapper extends BaseGraphMapper<GraphDef,NodeDef,AttrValue,No
|
||||||
} else if(tensorProto.getOp().equals("Assert")){
|
} else if(tensorProto.getOp().equals("Assert")){
|
||||||
return org.nd4j.linalg.api.buffer.DataType.BOOL;
|
return org.nd4j.linalg.api.buffer.DataType.BOOL;
|
||||||
}
|
}
|
||||||
log.warn("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", tensorProto.getName(), tensorProto.getOp());
|
//Not in ops.proto
|
||||||
|
log.debug("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", tensorProto.getName(), tensorProto.getOp());
|
||||||
|
|
||||||
//No descriptor... try to fall back on common type attribute names
|
//No descriptor... try to fall back on common type attribute names
|
||||||
if(!tensorProto.containsAttr("dtype") && !tensorProto.containsAttr("Tidx") && !tensorProto.containsAttr("T"))
|
if(!tensorProto.containsAttr("dtype") && !tensorProto.containsAttr("Tidx") && !tensorProto.containsAttr("T"))
|
||||||
|
|
|
@ -153,7 +153,7 @@ public class Mmul extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] tensorflowNames() {
|
public String[] tensorflowNames() {
|
||||||
return new String[]{"MatMul", "BatchMatMul"};
|
return new String[]{"MatMul", "BatchMatMul", "BatchMatMulV2"};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -171,7 +171,12 @@ public class Mmul extends DynamicCustomOp {
|
||||||
|
|
||||||
boolean isTransposeA;
|
boolean isTransposeA;
|
||||||
boolean isTransposeB;
|
boolean isTransposeB;
|
||||||
if(nodeDef.getOp().equalsIgnoreCase("BatchMatMul")){
|
if(nodeDef.getOp().equalsIgnoreCase("MatMul")){
|
||||||
|
isTransposeA = attributesForNode.get("transpose_a").getB();
|
||||||
|
isTransposeB = attributesForNode.get("transpose_b").getB();
|
||||||
|
|
||||||
|
} else {
|
||||||
|
//BatchMatMul, BatchMatMulV2
|
||||||
//In practice, BatchMatMul seems to use "adj_x" and "adj_y" instead of "transpose_a" and "transpose_b"
|
//In practice, BatchMatMul seems to use "adj_x" and "adj_y" instead of "transpose_a" and "transpose_b"
|
||||||
if(attributesForNode.containsKey("transpose_a")){
|
if(attributesForNode.containsKey("transpose_a")){
|
||||||
isTransposeA = attributesForNode.get("transpose_a").getB();
|
isTransposeA = attributesForNode.get("transpose_a").getB();
|
||||||
|
@ -183,9 +188,6 @@ public class Mmul extends DynamicCustomOp {
|
||||||
} else {
|
} else {
|
||||||
isTransposeB = attributesForNode.get("adj_y").getB();
|
isTransposeB = attributesForNode.get("adj_y").getB();
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
isTransposeA = attributesForNode.get("transpose_a").getB();
|
|
||||||
isTransposeB = attributesForNode.get("transpose_b").getB();
|
|
||||||
}
|
}
|
||||||
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
||||||
.transposeA(isTransposeA).transposeB(isTransposeB)
|
.transposeA(isTransposeA).transposeB(isTransposeB)
|
||||||
|
|
Loading…
Reference in New Issue