From bfd9e3692a5a7e67772ed9aa4541857faf83ab2f Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 17 Dec 2019 10:23:37 +1100 Subject: [PATCH] Add op counting to TensorFlowImportValidator (#128) * Add op counting to TensorFlowImportValidator Signed-off-by: AlexDBlack * Test tweak Signed-off-by: AlexDBlack --- .../org/datavec/image/loader/TestNativeImageLoader.java | 4 ++++ .../java/org/nd4j/imports/tensorflow/TFImportStatus.java | 8 ++++++++ .../imports/tensorflow/TensorFlowImportValidator.java | 5 +++++ 3 files changed, 17 insertions(+) diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java index 5f634bab8..b8fa0c43d 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java @@ -570,6 +570,7 @@ public class TestNativeImageLoader { try(InputStream is = new FileInputStream(f)){ nil.asMatrix(is); + fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg, msg.contains("decode image")); @@ -577,6 +578,7 @@ public class TestNativeImageLoader { try(InputStream is = new FileInputStream(f)){ nil.asImageMatrix(is); + fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg, msg.contains("decode image")); @@ -584,6 +586,7 @@ public class TestNativeImageLoader { try(InputStream is = new FileInputStream(f)){ nil.asRowVector(is); + fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg, msg.contains("decode image")); @@ -592,6 +595,7 @@ public class TestNativeImageLoader { try(InputStream is = new FileInputStream(f)){ INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32); nil.asMatrixView(is, arr); + fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg, msg.contains("decode image")); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TFImportStatus.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TFImportStatus.java index 9d5ec8098..332901c4b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TFImportStatus.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TFImportStatus.java @@ -38,6 +38,8 @@ public class TFImportStatus { private final int numUniqueOps; /** The (unique) names of all ops encountered in all graphs */ private final Set opNames; + /** The number of times each operation was observed in all graphs */ + private final Map opCounts; /** The (unique) names of all ops that were encountered, and can be imported, in all graphs */ private final Set importSupportedOpNames; /** The (unique) names of all ops that were encountered, and can NOT be imported (lacking import mapping) */ @@ -60,6 +62,11 @@ public class TFImportStatus { Set newOpNames = new HashSet<>(opNames); newOpNames.addAll(other.opNames); + Map newOpCounts = new HashMap<>(opCounts); + for(Map.Entry e : other.opCounts.entrySet()){ + newOpCounts.put(e.getKey(), (newOpCounts.containsKey(e.getKey()) ? newOpCounts.get(e.getKey()) : 0) + e.getValue()); + } + Set newImportSupportedOpNames = new HashSet<>(importSupportedOpNames); newImportSupportedOpNames.addAll(other.importSupportedOpNames); @@ -89,6 +96,7 @@ public class TFImportStatus { totalNumOps + other.totalNumOps, countUnique, newOpNames, + newOpCounts, newImportSupportedOpNames, newUnsupportedOpNames, newUnsupportedOpModels); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java index 39d8e1577..0493099f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java @@ -230,6 +230,7 @@ public class TensorFlowImportValidator { try { int opCount = 0; Set opNames = new HashSet<>(); + Map opCounts = new HashMap<>(); try(InputStream bis = new BufferedInputStream(is)) { GraphDef graphDef = GraphDef.parseFrom(bis); @@ -248,6 +249,8 @@ public class TensorFlowImportValidator { String op = nd.getOp(); opNames.add(op); + int soFar = opCounts.containsKey(op) ? opCounts.get(op) : 0; + opCounts.put(op, soFar + 1); opCount++; } } @@ -282,6 +285,7 @@ public class TensorFlowImportValidator { opCount, opNames.size(), opNames, + opCounts, importSupportedOpNames, unsupportedOpNames, unsupportedOpModel); @@ -297,6 +301,7 @@ public class TensorFlowImportValidator { 0, 0, Collections.emptySet(), + Collections.emptyMap(), Collections.emptySet(), Collections.emptySet(), Collections.>emptyMap());