From 35ab4a72ba44079257bec3254175806f5b0a03e9 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 30 Nov 2019 18:58:37 +1100 Subject: [PATCH] TF import test resources loading precision fixes (#92) * Fix precision issues when loading from CSV Signed-off-by: AlexDBlack * Small tweak Signed-off-by: AlexDBlack --- .../imports/graphmapper/tf/TFGraphMapper.java | 8 ++ .../TFGraphs/TFGraphTestAllHelper.java | 127 +++++++++++++++--- 2 files changed, 118 insertions(+), 17 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java index 8605467cc..f54b532e8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java @@ -55,6 +55,14 @@ import java.util.*; @Slf4j public class TFGraphMapper { + /** + * @deprecated Use static methods - {@link #importGraph(File)} etc + */ + @Deprecated + public static TFGraphMapper getInstance(){ + return new TFGraphMapper(); + } + /** * Import a frozen TensorFlow protobuf (.pb) file from the specified file * diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index 6582d38db..eae14b230 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -594,7 +594,7 @@ public class TFGraphTestAllHelper { val key = modelDir + "/" + okey; // parse type directly - val value = ArrayOptionsHelper.dataType(split[1]); + DataType value = ArrayOptionsHelper.dataType(split[1]); // adding key directly //if (dtypes.containsKey(key)) @@ -672,12 +672,35 @@ public class TFGraphTestAllHelper { INDArray varValue; if(filtered.size() == 0){ //Scalar - float[] varContents; - try(InputStream is = new BufferedInputStream(resources.get(i).getSecond().getInputStream())){ - varContents = Nd4j.readNumpy(is, ",").data().asFloat(); + String content = IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8); + switch (type){ + case DOUBLE: + case FLOAT: + case HALF: + case BFLOAT16: + varValue = Nd4j.scalar(type, parseDouble(content)); + break; + case LONG: + case INT: + case SHORT: + case UBYTE: + case BYTE: + case UINT16: + case UINT32: + case UINT64: + varValue = Nd4j.scalar(type, parseLong(content)); + break; + case BOOL: + varValue = Nd4j.scalar(parseBoolean(content)); + break; + case UTF8: + varValue = Nd4j.scalar(content); + break; + case COMPRESSED: + case UNKNOWN: + default: + throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type); } - Preconditions.checkState(varContents.length == 1, "Expected length 1 content for scalar shape; got length %s", varContents.length); - varValue = Nd4j.scalar(type, varContents[0]); } else { int[] varShape = new int[filtered.size()]; for( int j=0; j testPrecisionOverride(String testName){ if("conv_4".equalsIgnoreCase(testName)){