From 8bc3172e407bdc51f3a62dd3b72b503e9a8fd3e9 Mon Sep 17 00:00:00 2001 From: agibsonccc Date: Tue, 16 Feb 2021 11:02:27 +0900 Subject: [PATCH] Fix NCHW case for fused batch norm --- libnd4j/include/array/NDArray.hXX | 22 +-- .../declarable/generic/nn/fusedBatchNorm.cpp | 28 ++- .../TFGraphs/TFGraphTestAllHelper.java | 2 +- .../TFGraphs/TFGraphTestAllSameDiff.java | 21 +-- .../TFGraphs/ValidateZooModelPredictions.java | 4 + .../definitions/TensorflowOpDeclarations.kt | 1 - .../tensorflow-mapping-ruleset.pbtxt | 12 -- .../tensorflow/TestTensorflowIR.kt | 161 ++++++++++++++++-- .../tensorflow-processes.pbtxt | 12 -- 9 files changed, 193 insertions(+), 70 deletions(-) diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index 530cf50c0..acfdc9e4d 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -1578,20 +1578,20 @@ namespace sd { int rank = shape::rank(_shapeInfo); int lim = shape::shapeInfoLength(rank); - if(msg != nullptr) - printf("shapeInfo %s: [", msg); - else - printf("shapeInfo: ["); - - printf("%i, ", rank); + if(msg != nullptr) { + nd4j_printf("shapeInfo %s: [", msg); + } else { + nd4j_printf("shapeInfo: [%s", ""); + } + nd4j_printf("%i, ", rank); for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){ if(i == rank + 1) - printf(" "); - printf("%lld,", _shapeInfo[i]); + nd4j_printf(" ",""); + nd4j_printf("%lld,", _shapeInfo[i]); } - printf(" %lld,", shape::type(_shapeInfo)); - printf("%lld,", shape::elementWiseStride(_shapeInfo)); - printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo)); + nd4j_printf(" %lld,", shape::type(_shapeInfo)); + nd4j_printf("%lld,", shape::elementWiseStride(_shapeInfo)); + nd4j_printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo)); fflush(stdout); } diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index 0a7525a35..dc957bee0 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -45,6 +45,7 @@ namespace sd { const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW const bool isTraining = (bool)INT_ARG(1); + nd4j_debug("CUSTOM_OP fused_batch_norm: data format, is NCHW: %d, isTraining: %d\n",dataFormat,isTraining); REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf()); @@ -62,8 +63,19 @@ namespace sd { } auto xCast = x->cast(sd::DataType::FLOAT32); - - + //move to NWHC + /** + * TODO: TF has a permute to NWHC here: + * https://github.com/tensorflow/tensorflow/blob/ce34a83e03394492b1c4e5bb92fbd56da2ba7ce5/tensorflow/core/kernels/fused_batch_norm_op.cc#L137 + * + * This should be done as well for us, but results are still off. + * Figure out differences. + */ + if(dataFormat) { + xCast.printShapeInfo("x cast shape info pre permute"); + xCast = xCast.permute({0, 2, 3, 1}); + xCast.printShapeInfo("x cast shape info post permute"); + } REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str()); REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str()); @@ -126,14 +138,22 @@ namespace sd { auto scaledVariance = ((*variance + epsilon).transform(transform::RSqrt) * (*scale)).cast(xAffected.dataType()); auto xScaled1 = xCentered * scaledVariance; auto xShifted1 = xScaled1 + *offset; + if(dataFormat) { + //need to reshape from matrix to 4d then permute the ordering due to NWHC ordering + auto reshaped = xShifted1.reshape(xCast.ordering(),xCast.getShapeAsVector()); + reshaped.permutei({0,3,1,2}); + y->assign(reshaped); + + } + else //NWHC case + y->assign(xShifted1); - y->assign(xShifted1); if(isTraining) { delete mean; delete variance; } - + return Status::OK(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index f99d92057..4738083d9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -88,6 +88,7 @@ import static org.nd4j.imports.tfgraphs.TFGraphsSkipNodes.skipNode; @Slf4j public class TFGraphTestAllHelper { public static final String resourceFolderVar = "DL4J_TEST_RESOURCES"; + public static TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter(); public enum ExecuteWith { SAMEDIFF, LIBND4J, JUST_PRINT @@ -103,7 +104,6 @@ public class TFGraphTestAllHelper { e.printStackTrace(); } - TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter(); return tensorflowFrameworkImporter.runImport(file.getAbsolutePath(),Collections.emptyMap()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 06817a6c9..438a6f797 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -76,30 +76,15 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "layers_dropout/rank3_d05_train_mask1", "layers_dropout/rank2_d09_train", "layers_dropout/rank2_d05_train",*/ - /* "primitive_gru_dynamic", - "layers_dropout/rank4_d05_train", - "fused_batch_norm/float16_nhwc", - "rnn/lstmblockcell/dynamic_b1_n5-3_ts4_noPH_noClip_fB1_noIS_withTM", - "rnn/lstmcell/dynamic_b1_nIn5_nOut3_ts4_noPH_noClip_fB1_Tanh_noIS_float_withTM", - "rnn/grublockcellv2/dynamic_b1_n3-2_ts1_noIS_noTM"*/ - /* "unsorted_segment/unsorted_segment_mean_rank3", - "unsorted_segment/unsorted_segment_sqrt_n_rank2", - "unsorted_segment/unsorted_segment_mean_rank2", - "unsorted_segment/unsorted_segment_mean_rank3", - "unsorted_segment/unsorted_segment_sum_rank3", - "unsorted_segment/unsorted_segment_min_rank2", - "unsorted_segment/unsorted_segment_prod_rank2", - "unsorted_segment/unsorted_segment_max_rank2",*/ - "bincount/rank0_weights", - "bincount/rank2_weights" - /* "compare_and_bitpack/bool", + + "compare_and_bitpack/bool", "compare_and_bitpack/float32", "compare_and_bitpack/float64", "compare_and_bitpack/half", "compare_and_bitpack/int32", "compare_and_bitpack/int8", "compare_and_bitpack/int64", - "compare_and_bitpack/int16"*/ + "compare_and_bitpack/int16" diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java index 853b94ab3..019d36010 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java @@ -64,6 +64,10 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { Nd4j.getRandom().setSeed(123); } + @Override + public long getTimeoutMilliseconds() { + return Long.MAX_VALUE; + } @Test public void testMobilenetV1() throws Exception { diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt index 117cd9faa..feea382e1 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt @@ -489,7 +489,6 @@ val compareAndBitPack = TensorflowMappingProcess( opName = "compare_and_bitpack", opMappingRegistry = tensorflowOpRegistry, inputFrameworkOpName = "CompareAndBitpack", - attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))), tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input","y" to "threshold"))) ) diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt b/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt index e6cf75dee..5183dad4a 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt @@ -9137,18 +9137,6 @@ mappings { ruleType: "tensor" inputFrameworkOpName: "CompareAndBitpack" } - rule { - ruleName: "valuemapping" - functionName: "valuemapping" - inputDataTypeName: "T" - outputDataTypeName: "dtype" - inputToOutput { - key: "dtype" - value: "T" - } - ruleType: "attribute" - inputFrameworkOpName: "CompareAndBitpack" - } } mappings { frameworkName: "tensorflow" diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt b/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt index e7666d79f..409b611cb 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt @@ -24,38 +24,32 @@ import junit.framework.Assert.assertEquals import junit.framework.Assert.assertTrue import org.apache.commons.io.FileUtils import org.apache.commons.io.IOUtils +import org.junit.Assert import org.junit.Ignore import org.junit.jupiter.api.Test -import org.nd4j.autodiff.samediff.SameDiff import org.nd4j.common.io.ClassPathResource import org.nd4j.imports.graphmapper.tf.TFGraphMapper import org.nd4j.ir.OpNamespace import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.ops.DynamicCustomOp -import org.nd4j.linalg.api.ops.custom.Roll -import org.nd4j.linalg.api.ops.impl.transforms.BinCount -import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.profiler.ProfilerConfig import org.nd4j.samediff.frameworkimport.ImportGraph import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder -import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry import org.nd4j.samediff.frameworkimport.registry.OpRegistryHolder import org.nd4j.samediff.frameworkimport.tensorflow.context.TensorflowMappingContext import org.nd4j.samediff.frameworkimport.tensorflow.definitions.registry import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner -import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode -import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRTensor -import org.nd4j.shade.protobuf.ByteString import org.nd4j.shade.protobuf.TextFormat -import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner import org.tensorflow.framework.* import java.io.File -import java.lang.IllegalStateException import java.nio.charset.Charset -import kotlin.math.max +import java.nio.charset.StandardCharsets +import java.util.* +import kotlin.collections.HashMap +import kotlin.collections.HashSet data class GraphInput(val graphDef: GraphDef,val inputNames: List,val outputNames: List, val inputArrays: Map,val dynamicArrays: Map) @@ -78,6 +72,7 @@ class TestTensorflowIR { fun manualTest() { val manualGraph = FileUtils.readFileToString(File("test.pbtxt"),Charset.defaultCharset()) val parsedGraph = GraphDef.newBuilder() + //C:\Users\agibs\.nd4jtests\resnetv2_imagenet_frozen_graph TextFormat.merge(manualGraph,parsedGraph) val textGraph = parsedGraph.build() println(textGraph) @@ -155,6 +150,150 @@ class TestTensorflowIR { //assertEquals(tfOutput,output) } + + @Test + @Ignore + fun manualTestBinary() { + val path = "C:\\Users\\agibs\\.nd4jtests\\resnetv2_imagenet_frozen_graph\\resnetv2_imagenet_frozen_graph.pb" + val bytes = FileUtils.readFileToByteArray(File(path)) + val parsedGraph = GraphDef.parseFrom(bytes) + val tfImporter = TensorflowFrameworkImporter() + //with names [image] and shapes {image=[4, 2, 28, 28, 3]} + Nd4j.getEnvironment().isDebug = true + Nd4j.getEnvironment().isVerbose = true + //TFGraphMapper.importGraph(textGraph) + // val inputMap = mapOf("input_1" to Nd4j.zeros(10).castTo(org.nd4j.linalg.api.buffer.DataType.INT32),"input_2" to Nd4j.zeros(1,8).castTo(org.nd4j.linalg.api.buffer.DataType.DOUBLE)) + //val inputMap = mapOf("image" to Nd4j.ones(1,128,128,4)) + /** + * TODO: fix emptyReduce. + * When we pass in 2 inputs where input 1 is the dimensions, the results + * work. In our model import, it appears that + * the empty dimensions aren't being passed down + * for int arguments properly. + * We need to figure out the difference between specifying 2 input arrays + * and ints, that or we need to make it so that arrays can be passed in + * for dimensions for each singular reduce op. + * + * Each op seems to be able to take in dimensions for indices. + * It *MIGHT* be better just to pass in dimensions directly. + */ + + + //Load data + //Because we don't have DataVec NativeImageLoader in ND4J tests due to circular dependencies, we'll load the image previously saved... + var imgFile = + File("goldenretriever_rgb224_unnormalized_nchw_INDArray.bin") + var img = Nd4j.readBinary(imgFile).castTo(org.nd4j.linalg.api.buffer.DataType.FLOAT) + img = img.permute(0, 2, 3, 1).dup() //to NHWC + + + //Perform inference + + //Resnet v2 - NO external normalization, just resize and center crop + // https://github.com/tensorflow/models/blob/d32d957a02f5cffb745a4da0d78f8432e2c52fd4/research/tensorrt/tensorrt.py#L70 + // https://github.com/tensorflow/models/blob/1af55e018eebce03fb61bba9959a04672536107d/official/resnet/imagenet_preprocessing.py#L253-L256 + + val importedGraph = TFGraphMapper.importGraph(parsedGraph) + + //Load labels + val labels = labels() + + + //Perform inference + val inputs: List = importedGraph.inputs() + Assert.assertEquals(1, inputs.size.toLong()) + + val out = "softmax_tensor" + val m: Map = importedGraph.output(Collections.singletonMap(inputs[0], img), out) + + val outArr = m[out] + + + println("SHAPE: " + Arrays.toString(outArr!!.shape())) + println(outArr) + + val argmax = outArr!!.argMax(1) + + //Load labels + + val classIdx = argmax.getInt(0) + val className = labels[classIdx] + val expClass = "golden retriever" + val prob = outArr!!.getDouble(classIdx.toLong()) + + println("Predicted class: $classIdx - \"$className\" - probability = $prob") + Assert.assertEquals(expClass, className) + + val inputMap = Collections.singletonMap(inputs[0], img) + val tensorflowIRGraph = TensorflowIRGraph(parsedGraph,tensorflowOps,tfImporter.registry) + val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toMutableSet() + val tfGraphRunner = TensorflowIRGraphRunner(tensorflowIRGraph, inputMap.keys.toList(),listOf("batch_normalization/FusedBatchNorm",out)) + val graph = tfImporter.importFromGraph(parsedGraph,inputMap) + val tfOutput = tfGraphRunner.run(inputMap) + + /** + * TODO: UnsortedSegmentSum ,Solution is almost there, need to figure out how to + * output correct shape. + * + * Shape in TF is 5 x 5 but actual real output seems to be 1 x 10. + * We need to change the output shape to work like TF does. + */ + val output2 = importedGraph.outputAll(inputMap) + val output = graph.outputAll(inputMap) + + + //assertEquals(tfOutput.keys,outputList) + //assertEquals(tfOutput.keys,output2.keys) + val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() } + val skipValidation = setOf("parallel_stack/ExpandDims/dim") + //assertEquals(output.keys,output2.keys) + val notEquals = LinkedHashSet() + val notEqualsTf = LinkedHashSet() + val notEqualsOp = LinkedHashSet() + names.forEach { + val value = output[it] + val value2 = output2[it] + val tfValue = tfOutput[it] + if(value!! != (value2!!)) { + val oldOps = importedGraph.ops[it] + val newOps = graph.ops[it] + val oldVar = importedGraph.variables[it] + val newVar = graph.variables[it] + notEquals.add(it) + } + + if(tfValue != null && tfValue!! != (value!!)) { + val oldOps = importedGraph.ops[it] + val newOps = graph.ops[it] + val oldVar = importedGraph.variables[it] + val newVar = graph.variables[it] + notEqualsTf.add(it) + } + + val oldOp = importedGraph.ops[it] + val newOp = graph.ops[it] + if(oldOp != newOp) { + notEqualsOp.add(it) + } + + } + + println(notEquals) + println(notEqualsTf) + println("Not equals ops $notEqualsOp") + println() + // assertEquals(output,output2) + //assertEquals(tfOutput,output) + } + + @Throws(Exception::class) + fun labels(): List { + val labelsFile = + File("imagenet_labellist.txt") + return FileUtils.readLines(labelsFile, StandardCharsets.UTF_8) + } + + @Test @Ignore fun manualTest2() { diff --git a/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt b/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt index e6cf75dee..5183dad4a 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt +++ b/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt @@ -9137,18 +9137,6 @@ mappings { ruleType: "tensor" inputFrameworkOpName: "CompareAndBitpack" } - rule { - ruleName: "valuemapping" - functionName: "valuemapping" - inputDataTypeName: "T" - outputDataTypeName: "dtype" - inputToOutput { - key: "dtype" - value: "T" - } - ruleType: "attribute" - inputFrameworkOpName: "CompareAndBitpack" - } } mappings { frameworkName: "tensorflow"