Add op counting to TensorFlowImportValidator (#128)

* Add op counting to TensorFlowImportValidator

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Test tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-12-17 10:23:37 +11:00 committed by GitHub
parent bd4f77c652
commit bfd9e3692a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 0 deletions

View File

@ -570,6 +570,7 @@ public class TestNativeImageLoader {
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
nil.asMatrix(is); nil.asMatrix(is);
fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg, msg.contains("decode image"));
@ -577,6 +578,7 @@ public class TestNativeImageLoader {
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
nil.asImageMatrix(is); nil.asImageMatrix(is);
fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg, msg.contains("decode image"));
@ -584,6 +586,7 @@ public class TestNativeImageLoader {
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
nil.asRowVector(is); nil.asRowVector(is);
fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg, msg.contains("decode image"));
@ -592,6 +595,7 @@ public class TestNativeImageLoader {
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32); INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
nil.asMatrixView(is, arr); nil.asMatrixView(is, arr);
fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg, msg.contains("decode image"));

View File

@ -38,6 +38,8 @@ public class TFImportStatus {
private final int numUniqueOps; private final int numUniqueOps;
/** The (unique) names of all ops encountered in all graphs */ /** The (unique) names of all ops encountered in all graphs */
private final Set<String> opNames; private final Set<String> opNames;
/** The number of times each operation was observed in all graphs */
private final Map<String,Integer> opCounts;
/** The (unique) names of all ops that were encountered, and can be imported, in all graphs */ /** The (unique) names of all ops that were encountered, and can be imported, in all graphs */
private final Set<String> importSupportedOpNames; private final Set<String> importSupportedOpNames;
/** The (unique) names of all ops that were encountered, and can NOT be imported (lacking import mapping) */ /** The (unique) names of all ops that were encountered, and can NOT be imported (lacking import mapping) */
@ -60,6 +62,11 @@ public class TFImportStatus {
Set<String> newOpNames = new HashSet<>(opNames); Set<String> newOpNames = new HashSet<>(opNames);
newOpNames.addAll(other.opNames); newOpNames.addAll(other.opNames);
Map<String,Integer> newOpCounts = new HashMap<>(opCounts);
for(Map.Entry<String,Integer> e : other.opCounts.entrySet()){
newOpCounts.put(e.getKey(), (newOpCounts.containsKey(e.getKey()) ? newOpCounts.get(e.getKey()) : 0) + e.getValue());
}
Set<String> newImportSupportedOpNames = new HashSet<>(importSupportedOpNames); Set<String> newImportSupportedOpNames = new HashSet<>(importSupportedOpNames);
newImportSupportedOpNames.addAll(other.importSupportedOpNames); newImportSupportedOpNames.addAll(other.importSupportedOpNames);
@ -89,6 +96,7 @@ public class TFImportStatus {
totalNumOps + other.totalNumOps, totalNumOps + other.totalNumOps,
countUnique, countUnique,
newOpNames, newOpNames,
newOpCounts,
newImportSupportedOpNames, newImportSupportedOpNames,
newUnsupportedOpNames, newUnsupportedOpNames,
newUnsupportedOpModels); newUnsupportedOpModels);

View File

@ -230,6 +230,7 @@ public class TensorFlowImportValidator {
try { try {
int opCount = 0; int opCount = 0;
Set<String> opNames = new HashSet<>(); Set<String> opNames = new HashSet<>();
Map<String,Integer> opCounts = new HashMap<>();
try(InputStream bis = new BufferedInputStream(is)) { try(InputStream bis = new BufferedInputStream(is)) {
GraphDef graphDef = GraphDef.parseFrom(bis); GraphDef graphDef = GraphDef.parseFrom(bis);
@ -248,6 +249,8 @@ public class TensorFlowImportValidator {
String op = nd.getOp(); String op = nd.getOp();
opNames.add(op); opNames.add(op);
int soFar = opCounts.containsKey(op) ? opCounts.get(op) : 0;
opCounts.put(op, soFar + 1);
opCount++; opCount++;
} }
} }
@ -282,6 +285,7 @@ public class TensorFlowImportValidator {
opCount, opCount,
opNames.size(), opNames.size(),
opNames, opNames,
opCounts,
importSupportedOpNames, importSupportedOpNames,
unsupportedOpNames, unsupportedOpNames,
unsupportedOpModel); unsupportedOpModel);
@ -297,6 +301,7 @@ public class TensorFlowImportValidator {
0, 0,
0, 0,
Collections.<String>emptySet(), Collections.<String>emptySet(),
Collections.<String, Integer>emptyMap(),
Collections.<String>emptySet(), Collections.<String>emptySet(),
Collections.<String>emptySet(), Collections.<String>emptySet(),
Collections.<String, Set<String>>emptyMap()); Collections.<String, Set<String>>emptyMap());