Add op counting to TensorFlowImportValidator (#128)
* Add op counting to TensorFlowImportValidator Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test tweak Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
bd4f77c652
commit
bfd9e3692a
|
@ -570,6 +570,7 @@ public class TestNativeImageLoader {
|
|||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
nil.asMatrix(is);
|
||||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
|
@ -577,6 +578,7 @@ public class TestNativeImageLoader {
|
|||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
nil.asImageMatrix(is);
|
||||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
|
@ -584,6 +586,7 @@ public class TestNativeImageLoader {
|
|||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
nil.asRowVector(is);
|
||||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
|
@ -592,6 +595,7 @@ public class TestNativeImageLoader {
|
|||
try(InputStream is = new FileInputStream(f)){
|
||||
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
|
||||
nil.asMatrixView(is, arr);
|
||||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
|
|
|
@ -38,6 +38,8 @@ public class TFImportStatus {
|
|||
private final int numUniqueOps;
|
||||
/** The (unique) names of all ops encountered in all graphs */
|
||||
private final Set<String> opNames;
|
||||
/** The number of times each operation was observed in all graphs */
|
||||
private final Map<String,Integer> opCounts;
|
||||
/** The (unique) names of all ops that were encountered, and can be imported, in all graphs */
|
||||
private final Set<String> importSupportedOpNames;
|
||||
/** The (unique) names of all ops that were encountered, and can NOT be imported (lacking import mapping) */
|
||||
|
@ -60,6 +62,11 @@ public class TFImportStatus {
|
|||
Set<String> newOpNames = new HashSet<>(opNames);
|
||||
newOpNames.addAll(other.opNames);
|
||||
|
||||
Map<String,Integer> newOpCounts = new HashMap<>(opCounts);
|
||||
for(Map.Entry<String,Integer> e : other.opCounts.entrySet()){
|
||||
newOpCounts.put(e.getKey(), (newOpCounts.containsKey(e.getKey()) ? newOpCounts.get(e.getKey()) : 0) + e.getValue());
|
||||
}
|
||||
|
||||
Set<String> newImportSupportedOpNames = new HashSet<>(importSupportedOpNames);
|
||||
newImportSupportedOpNames.addAll(other.importSupportedOpNames);
|
||||
|
||||
|
@ -89,6 +96,7 @@ public class TFImportStatus {
|
|||
totalNumOps + other.totalNumOps,
|
||||
countUnique,
|
||||
newOpNames,
|
||||
newOpCounts,
|
||||
newImportSupportedOpNames,
|
||||
newUnsupportedOpNames,
|
||||
newUnsupportedOpModels);
|
||||
|
|
|
@ -230,6 +230,7 @@ public class TensorFlowImportValidator {
|
|||
try {
|
||||
int opCount = 0;
|
||||
Set<String> opNames = new HashSet<>();
|
||||
Map<String,Integer> opCounts = new HashMap<>();
|
||||
|
||||
try(InputStream bis = new BufferedInputStream(is)) {
|
||||
GraphDef graphDef = GraphDef.parseFrom(bis);
|
||||
|
@ -248,6 +249,8 @@ public class TensorFlowImportValidator {
|
|||
|
||||
String op = nd.getOp();
|
||||
opNames.add(op);
|
||||
int soFar = opCounts.containsKey(op) ? opCounts.get(op) : 0;
|
||||
opCounts.put(op, soFar + 1);
|
||||
opCount++;
|
||||
}
|
||||
}
|
||||
|
@ -282,6 +285,7 @@ public class TensorFlowImportValidator {
|
|||
opCount,
|
||||
opNames.size(),
|
||||
opNames,
|
||||
opCounts,
|
||||
importSupportedOpNames,
|
||||
unsupportedOpNames,
|
||||
unsupportedOpModel);
|
||||
|
@ -297,6 +301,7 @@ public class TensorFlowImportValidator {
|
|||
0,
|
||||
0,
|
||||
Collections.<String>emptySet(),
|
||||
Collections.<String, Integer>emptyMap(),
|
||||
Collections.<String>emptySet(),
|
||||
Collections.<String>emptySet(),
|
||||
Collections.<String, Set<String>>emptyMap());
|
||||
|
|
Loading…
Reference in New Issue