Add op counting to TensorFlowImportValidator (#128)
* Add op counting to TensorFlowImportValidator Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test tweak Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
bd4f77c652
commit
bfd9e3692a
|
@ -570,6 +570,7 @@ public class TestNativeImageLoader {
|
||||||
|
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
nil.asMatrix(is);
|
nil.asMatrix(is);
|
||||||
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg, msg.contains("decode image"));
|
||||||
|
@ -577,6 +578,7 @@ public class TestNativeImageLoader {
|
||||||
|
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
nil.asImageMatrix(is);
|
nil.asImageMatrix(is);
|
||||||
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg, msg.contains("decode image"));
|
||||||
|
@ -584,6 +586,7 @@ public class TestNativeImageLoader {
|
||||||
|
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
nil.asRowVector(is);
|
nil.asRowVector(is);
|
||||||
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg, msg.contains("decode image"));
|
||||||
|
@ -592,6 +595,7 @@ public class TestNativeImageLoader {
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
|
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
|
||||||
nil.asMatrixView(is, arr);
|
nil.asMatrixView(is, arr);
|
||||||
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg, msg.contains("decode image"));
|
||||||
|
|
|
@ -38,6 +38,8 @@ public class TFImportStatus {
|
||||||
private final int numUniqueOps;
|
private final int numUniqueOps;
|
||||||
/** The (unique) names of all ops encountered in all graphs */
|
/** The (unique) names of all ops encountered in all graphs */
|
||||||
private final Set<String> opNames;
|
private final Set<String> opNames;
|
||||||
|
/** The number of times each operation was observed in all graphs */
|
||||||
|
private final Map<String,Integer> opCounts;
|
||||||
/** The (unique) names of all ops that were encountered, and can be imported, in all graphs */
|
/** The (unique) names of all ops that were encountered, and can be imported, in all graphs */
|
||||||
private final Set<String> importSupportedOpNames;
|
private final Set<String> importSupportedOpNames;
|
||||||
/** The (unique) names of all ops that were encountered, and can NOT be imported (lacking import mapping) */
|
/** The (unique) names of all ops that were encountered, and can NOT be imported (lacking import mapping) */
|
||||||
|
@ -60,6 +62,11 @@ public class TFImportStatus {
|
||||||
Set<String> newOpNames = new HashSet<>(opNames);
|
Set<String> newOpNames = new HashSet<>(opNames);
|
||||||
newOpNames.addAll(other.opNames);
|
newOpNames.addAll(other.opNames);
|
||||||
|
|
||||||
|
Map<String,Integer> newOpCounts = new HashMap<>(opCounts);
|
||||||
|
for(Map.Entry<String,Integer> e : other.opCounts.entrySet()){
|
||||||
|
newOpCounts.put(e.getKey(), (newOpCounts.containsKey(e.getKey()) ? newOpCounts.get(e.getKey()) : 0) + e.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
Set<String> newImportSupportedOpNames = new HashSet<>(importSupportedOpNames);
|
Set<String> newImportSupportedOpNames = new HashSet<>(importSupportedOpNames);
|
||||||
newImportSupportedOpNames.addAll(other.importSupportedOpNames);
|
newImportSupportedOpNames.addAll(other.importSupportedOpNames);
|
||||||
|
|
||||||
|
@ -89,6 +96,7 @@ public class TFImportStatus {
|
||||||
totalNumOps + other.totalNumOps,
|
totalNumOps + other.totalNumOps,
|
||||||
countUnique,
|
countUnique,
|
||||||
newOpNames,
|
newOpNames,
|
||||||
|
newOpCounts,
|
||||||
newImportSupportedOpNames,
|
newImportSupportedOpNames,
|
||||||
newUnsupportedOpNames,
|
newUnsupportedOpNames,
|
||||||
newUnsupportedOpModels);
|
newUnsupportedOpModels);
|
||||||
|
|
|
@ -230,6 +230,7 @@ public class TensorFlowImportValidator {
|
||||||
try {
|
try {
|
||||||
int opCount = 0;
|
int opCount = 0;
|
||||||
Set<String> opNames = new HashSet<>();
|
Set<String> opNames = new HashSet<>();
|
||||||
|
Map<String,Integer> opCounts = new HashMap<>();
|
||||||
|
|
||||||
try(InputStream bis = new BufferedInputStream(is)) {
|
try(InputStream bis = new BufferedInputStream(is)) {
|
||||||
GraphDef graphDef = GraphDef.parseFrom(bis);
|
GraphDef graphDef = GraphDef.parseFrom(bis);
|
||||||
|
@ -248,6 +249,8 @@ public class TensorFlowImportValidator {
|
||||||
|
|
||||||
String op = nd.getOp();
|
String op = nd.getOp();
|
||||||
opNames.add(op);
|
opNames.add(op);
|
||||||
|
int soFar = opCounts.containsKey(op) ? opCounts.get(op) : 0;
|
||||||
|
opCounts.put(op, soFar + 1);
|
||||||
opCount++;
|
opCount++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -282,6 +285,7 @@ public class TensorFlowImportValidator {
|
||||||
opCount,
|
opCount,
|
||||||
opNames.size(),
|
opNames.size(),
|
||||||
opNames,
|
opNames,
|
||||||
|
opCounts,
|
||||||
importSupportedOpNames,
|
importSupportedOpNames,
|
||||||
unsupportedOpNames,
|
unsupportedOpNames,
|
||||||
unsupportedOpModel);
|
unsupportedOpModel);
|
||||||
|
@ -297,6 +301,7 @@ public class TensorFlowImportValidator {
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
Collections.<String>emptySet(),
|
Collections.<String>emptySet(),
|
||||||
|
Collections.<String, Integer>emptyMap(),
|
||||||
Collections.<String>emptySet(),
|
Collections.<String>emptySet(),
|
||||||
Collections.<String>emptySet(),
|
Collections.<String>emptySet(),
|
||||||
Collections.<String, Set<String>>emptyMap());
|
Collections.<String, Set<String>>emptyMap());
|
||||||
|
|
Loading…
Reference in New Issue