Fix NCHW case for fused batch norm

master
agibsonccc 2021-02-16 11:02:27 +09:00
parent e88d0fe96c
commit 8bc3172e40
9 changed files with 193 additions and 70 deletions

View File

@@ -1578,20 +1578,20 @@ namespace sd {
             int rank = shape::rank(_shapeInfo);
             int lim = shape::shapeInfoLength(rank);
-            if(msg != nullptr)
-                printf("shapeInfo %s: [", msg);
-            else
-                printf("shapeInfo: [");
-            printf("%i, ", rank);
+            if(msg != nullptr) {
+                nd4j_printf("shapeInfo %s: [", msg);
+            } else {
+                nd4j_printf("shapeInfo: [%s", "");
+            }
+            nd4j_printf("%i, ", rank);
             for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
                 if(i == rank + 1)
-                    printf(" ");
-                printf("%lld,", _shapeInfo[i]);
+                    nd4j_printf(" ","");
+                nd4j_printf("%lld,", _shapeInfo[i]);
             }
-            printf(" %lld,", shape::type(_shapeInfo));
-            printf("%lld,", shape::elementWiseStride(_shapeInfo));
-            printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
+            nd4j_printf(" %lld,", shape::type(_shapeInfo));
+            nd4j_printf("%lld,", shape::elementWiseStride(_shapeInfo));
+            nd4j_printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
             fflush(stdout);
         }
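A note on what this printer emits: a libnd4j shapeInfo is one flat buffer laid out as [rank, shape..., strides..., dataType, elementWiseStride, order]. The standalone sketch below (plain C++, independent of libnd4j; the buffer contents are hypothetical values for a 2x3 'c'-ordered array, including the type code) mirrors the loop above to show how the fields line up:

#include <cstdio>

int main() {
    // Hypothetical shapeInfo for a 2x3 'c'-ordered array:
    // [rank, shape..., strides..., dataType, elementWiseStride, order]
    long long shapeInfo[] = {2, 2, 3, 3, 1, 5, 1, 'c'};
    int rank = (int) shapeInfo[0];
    int lim = 2 * rank + 4;                  // shape::shapeInfoLength(rank)
    printf("shapeInfo: [%i, ", rank);
    for (int i = 1; i < lim - 3; i++) {      // shape entries, then strides
        if (i == rank + 1) printf(" ");
        printf("%lld,", shapeInfo[i]);
    }
    printf(" %lld,", shapeInfo[lim - 3]);    // data type code
    printf("%lld,", shapeInfo[lim - 2]);     // element-wise stride
    printf("%lld]\n", shapeInfo[lim - 1]);   // order ('c' == 99)
    return 0;
}

Compiled standalone, this prints shapeInfo: [2, 2,3, 3,1, 5,1,99] — rank, shape, strides, type code, element-wise stride, and the ASCII code of the order character.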

View File

@@ -45,6 +45,7 @@ namespace sd {
     const bool dataFormat = (bool)INT_ARG(0);    // 0->NHWC, 1->NCHW
     const bool isTraining = (bool)INT_ARG(1);
+    nd4j_debug("CUSTOM_OP fused_batch_norm: data format, is NCHW: %d, isTraining: %d\n", dataFormat, isTraining);
     REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());

@@ -62,8 +63,19 @@ namespace sd {
     }
     auto xCast = x->cast(sd::DataType::FLOAT32);
+    // move to NHWC
+    /**
+     * TODO: TF has a permute to NHWC here:
+     * https://github.com/tensorflow/tensorflow/blob/ce34a83e03394492b1c4e5bb92fbd56da2ba7ce5/tensorflow/core/kernels/fused_batch_norm_op.cc#L137
+     *
+     * This should be done for us as well, but results are still off.
+     * Figure out the differences.
+     */
+    if(dataFormat) {
+        xCast.printShapeInfo("x cast shape info pre permute");
+        xCast = xCast.permute({0, 2, 3, 1});
+        xCast.printShapeInfo("x cast shape info post permute");
+    }
     REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
     REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());

@@ -126,14 +138,22 @@ namespace sd {
     auto scaledVariance = ((*variance + epsilon).transform(transform::RSqrt) * (*scale)).cast(xAffected.dataType());
     auto xScaled1 = xCentered * scaledVariance;
     auto xShifted1 = xScaled1 + *offset;
-    y->assign(xShifted1);
+    if(dataFormat) {
+        // need to reshape from matrix back to 4d, then permute back to NCHW,
+        // since the math above was carried out in NHWC ordering
+        auto reshaped = xShifted1.reshape(xCast.ordering(), xCast.getShapeAsVector());
+        reshaped.permutei({0, 3, 1, 2});
+        y->assign(reshaped);
+    }
+    else // NHWC case
+        y->assign(xShifted1);
     if(isTraining) {
         delete mean;
         delete variance;
     }
     return Status::OK();
 }
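To make the intent of the two new dataFormat branches concrete: scale, offset, mean, and variance are all shaped [C], so the op permutes an NCHW input to NHWC, applies y = (x - mean) * rsqrt(var + eps) * scale + offset with the channel axis last, and permutes the result back to NCHW before assigning the output. A minimal standalone sketch of that round trip (plain C++, no libnd4j; shapes and statistics are made-up values):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int b = 1, C = 2, H = 2, W = 2;
    std::vector<float> x(b * C * H * W);
    for (size_t i = 0; i < x.size(); i++) x[i] = (float) i;   // NCHW input

    // Per-channel statistics and affine parameters, all shaped [C].
    float mean[C]   = {1.0f, 2.0f};
    float var[C]    = {4.0f, 9.0f};
    float scale[C]  = {1.0f, 0.5f};
    float offset[C] = {0.0f, 1.0f};
    const float eps = 1e-3f;

    std::vector<float> y(x.size());                           // NCHW output
    for (int n = 0; n < b; n++)
        for (int h = 0; h < H; h++)
            for (int w = 0; w < W; w++)
                for (int c = 0; c < C; c++) {
                    // Channels innermost = iterating the NCHW buffer in NHWC
                    // order, so the [C] parameters broadcast over the last axis.
                    int nchw = ((n * C + c) * H + h) * W + w;
                    y[nchw] = (x[nchw] - mean[c]) / std::sqrt(var[c] + eps)
                              * scale[c] + offset[c];
                }

    for (float v : y) printf("%f ", v);
    printf("\n");
    return 0;
}

Looping channels innermost is the same NHWC view of the NCHW buffer that the permute({0,2,3,1})/permutei({0,3,1,2}) pair in the hunk above produces on NDArray views.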

View File

@@ -88,6 +88,7 @@ import static org.nd4j.imports.tfgraphs.TFGraphsSkipNodes.skipNode;
 @Slf4j
 public class TFGraphTestAllHelper {
     public static final String resourceFolderVar = "DL4J_TEST_RESOURCES";
+    public static TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter();

     public enum ExecuteWith {
         SAMEDIFF, LIBND4J, JUST_PRINT

@@ -103,7 +104,6 @@ public class TFGraphTestAllHelper {
             e.printStackTrace();
         }
-        TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter();
         return tensorflowFrameworkImporter.runImport(file.getAbsolutePath(),Collections.emptyMap());
     }
 }
} }

View File

@@ -76,30 +76,15 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
             "layers_dropout/rank3_d05_train_mask1",
             "layers_dropout/rank2_d09_train",
             "layers_dropout/rank2_d05_train",*/
-            /*  "primitive_gru_dynamic",
-            "layers_dropout/rank4_d05_train",
-            "fused_batch_norm/float16_nhwc",
-            "rnn/lstmblockcell/dynamic_b1_n5-3_ts4_noPH_noClip_fB1_noIS_withTM",
-            "rnn/lstmcell/dynamic_b1_nIn5_nOut3_ts4_noPH_noClip_fB1_Tanh_noIS_float_withTM",
-            "rnn/grublockcellv2/dynamic_b1_n3-2_ts1_noIS_noTM"*/
-            /*  "unsorted_segment/unsorted_segment_mean_rank3",
-            "unsorted_segment/unsorted_segment_sqrt_n_rank2",
-            "unsorted_segment/unsorted_segment_mean_rank2",
-            "unsorted_segment/unsorted_segment_mean_rank3",
-            "unsorted_segment/unsorted_segment_sum_rank3",
-            "unsorted_segment/unsorted_segment_min_rank2",
-            "unsorted_segment/unsorted_segment_prod_rank2",
-            "unsorted_segment/unsorted_segment_max_rank2",*/
-            "bincount/rank0_weights",
-            "bincount/rank2_weights"
-            /*  "compare_and_bitpack/bool",
+            "compare_and_bitpack/bool",
             "compare_and_bitpack/float32",
             "compare_and_bitpack/float64",
             "compare_and_bitpack/half",
             "compare_and_bitpack/int32",
             "compare_and_bitpack/int8",
             "compare_and_bitpack/int64",
-            "compare_and_bitpack/int16"*/
+            "compare_and_bitpack/int16"

View File

@@ -64,6 +64,10 @@ public class ValidateZooModelPredictions extends BaseNd4jTest {
         Nd4j.getRandom().setSeed(123);
     }

+    @Override
+    public long getTimeoutMilliseconds() {
+        return Long.MAX_VALUE;
+    }
+
     @Test
     public void testMobilenetV1() throws Exception {

View File

@@ -489,7 +489,6 @@ val compareAndBitPack = TensorflowMappingProcess(
     opName = "compare_and_bitpack",
     opMappingRegistry = tensorflowOpRegistry,
     inputFrameworkOpName = "CompareAndBitpack",
-    attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))),
     tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input","y" to "threshold")))
 )

View File

@@ -9137,18 +9137,6 @@ mappings {
     ruleType: "tensor"
     inputFrameworkOpName: "CompareAndBitpack"
   }
-  rule {
-    ruleName: "valuemapping"
-    functionName: "valuemapping"
-    inputDataTypeName: "T"
-    outputDataTypeName: "dtype"
-    inputToOutput {
-      key: "dtype"
-      value: "T"
-    }
-    ruleType: "attribute"
-    inputFrameworkOpName: "CompareAndBitpack"
-  }
 }
 mappings {
   frameworkName: "tensorflow"

View File

@ -24,38 +24,32 @@ import junit.framework.Assert.assertEquals
import junit.framework.Assert.assertTrue import junit.framework.Assert.assertTrue
import org.apache.commons.io.FileUtils import org.apache.commons.io.FileUtils
import org.apache.commons.io.IOUtils import org.apache.commons.io.IOUtils
import org.junit.Assert
import org.junit.Ignore import org.junit.Ignore
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.common.io.ClassPathResource import org.nd4j.common.io.ClassPathResource
import org.nd4j.imports.graphmapper.tf.TFGraphMapper import org.nd4j.imports.graphmapper.tf.TFGraphMapper
import org.nd4j.ir.OpNamespace import org.nd4j.ir.OpNamespace
import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.DynamicCustomOp import org.nd4j.linalg.api.ops.DynamicCustomOp
import org.nd4j.linalg.api.ops.custom.Roll
import org.nd4j.linalg.api.ops.impl.transforms.BinCount
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.profiler.ProfilerConfig import org.nd4j.linalg.profiler.ProfilerConfig
import org.nd4j.samediff.frameworkimport.ImportGraph import org.nd4j.samediff.frameworkimport.ImportGraph
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder
import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry
import org.nd4j.samediff.frameworkimport.registry.OpRegistryHolder import org.nd4j.samediff.frameworkimport.registry.OpRegistryHolder
import org.nd4j.samediff.frameworkimport.tensorflow.context.TensorflowMappingContext import org.nd4j.samediff.frameworkimport.tensorflow.context.TensorflowMappingContext
import org.nd4j.samediff.frameworkimport.tensorflow.definitions.registry import org.nd4j.samediff.frameworkimport.tensorflow.definitions.registry
import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRTensor
import org.nd4j.shade.protobuf.ByteString
import org.nd4j.shade.protobuf.TextFormat import org.nd4j.shade.protobuf.TextFormat
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner
import org.tensorflow.framework.* import org.tensorflow.framework.*
import java.io.File import java.io.File
import java.lang.IllegalStateException
import java.nio.charset.Charset import java.nio.charset.Charset
import kotlin.math.max import java.nio.charset.StandardCharsets
import java.util.*
import kotlin.collections.HashMap
import kotlin.collections.HashSet
data class GraphInput(val graphDef: GraphDef,val inputNames: List<String>,val outputNames: List<String>, data class GraphInput(val graphDef: GraphDef,val inputNames: List<String>,val outputNames: List<String>,
val inputArrays: Map<String,INDArray>,val dynamicArrays: Map<String,INDArray>) val inputArrays: Map<String,INDArray>,val dynamicArrays: Map<String,INDArray>)
@@ -78,6 +72,7 @@ class TestTensorflowIR {
     fun manualTest() {
         val manualGraph = FileUtils.readFileToString(File("test.pbtxt"),Charset.defaultCharset())
         val parsedGraph = GraphDef.newBuilder()
+        //C:\Users\agibs\.nd4jtests\resnetv2_imagenet_frozen_graph
         TextFormat.merge(manualGraph,parsedGraph)
         val textGraph = parsedGraph.build()
         println(textGraph)
@@ -155,6 +150,150 @@ class TestTensorflowIR {
         //assertEquals(tfOutput,output)
     }

+    @Test
+    @Ignore
+    fun manualTestBinary() {
+        val path = "C:\\Users\\agibs\\.nd4jtests\\resnetv2_imagenet_frozen_graph\\resnetv2_imagenet_frozen_graph.pb"
+        val bytes = FileUtils.readFileToByteArray(File(path))
+        val parsedGraph = GraphDef.parseFrom(bytes)
+        val tfImporter = TensorflowFrameworkImporter()
+        //with names [image] and shapes {image=[4, 2, 28, 28, 3]}
+        Nd4j.getEnvironment().isDebug = true
+        Nd4j.getEnvironment().isVerbose = true
+        //TFGraphMapper.importGraph(textGraph)
+        //val inputMap = mapOf("input_1" to Nd4j.zeros(10).castTo(org.nd4j.linalg.api.buffer.DataType.INT32),"input_2" to Nd4j.zeros(1,8).castTo(org.nd4j.linalg.api.buffer.DataType.DOUBLE))
+        //val inputMap = mapOf("image" to Nd4j.ones(1,128,128,4))
+        /**
+         * TODO: fix emptyReduce.
+         * When we pass in 2 inputs where input 1 is the dimensions, the results
+         * work. In our model import, it appears that the empty dimensions aren't
+         * being passed down properly for int arguments.
+         * We need to figure out the difference between specifying 2 input arrays
+         * and ints, or we need to make it so that arrays can be passed in
+         * as dimensions for each singular reduce op.
+         *
+         * Each op seems to be able to take in dimensions for indices.
+         * It *MIGHT* be better just to pass in dimensions directly.
+         */
+
+        //Load data. Because we don't have DataVec's NativeImageLoader in ND4J tests
+        //due to circular dependencies, we load a previously saved image...
+        var imgFile = File("goldenretriever_rgb224_unnormalized_nchw_INDArray.bin")
+        var img = Nd4j.readBinary(imgFile).castTo(org.nd4j.linalg.api.buffer.DataType.FLOAT)
+        img = img.permute(0, 2, 3, 1).dup() //to NHWC
+
+        //Perform inference.
+        //Resnet v2 - NO external normalization, just resize and center crop:
+        //https://github.com/tensorflow/models/blob/d32d957a02f5cffb745a4da0d78f8432e2c52fd4/research/tensorrt/tensorrt.py#L70
+        //https://github.com/tensorflow/models/blob/1af55e018eebce03fb61bba9959a04672536107d/official/resnet/imagenet_preprocessing.py#L253-L256
+        val importedGraph = TFGraphMapper.importGraph(parsedGraph)
+
+        //Load labels
+        val labels = labels()
+
+        val inputs: List<String> = importedGraph.inputs()
+        Assert.assertEquals(1, inputs.size.toLong())
+
+        val out = "softmax_tensor"
+        val m: Map<String, INDArray> = importedGraph.output(Collections.singletonMap(inputs[0], img), out)
+        val outArr = m[out]
+        println("SHAPE: " + Arrays.toString(outArr!!.shape()))
+        println(outArr)
+
+        val argmax = outArr.argMax(1)
+        val classIdx = argmax.getInt(0)
+        val className = labels[classIdx]
+        val expClass = "golden retriever"
+        val prob = outArr.getDouble(classIdx.toLong())
+        println("Predicted class: $classIdx - \"$className\" - probability = $prob")
+        Assert.assertEquals(expClass, className)
+
+        val inputMap = Collections.singletonMap(inputs[0], img)
+        val tensorflowIRGraph = TensorflowIRGraph(parsedGraph,tensorflowOps,tfImporter.registry)
+        val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toMutableSet()
+        val tfGraphRunner = TensorflowIRGraphRunner(tensorflowIRGraph, inputMap.keys.toList(), listOf("batch_normalization/FusedBatchNorm", out))
+        val graph = tfImporter.importFromGraph(parsedGraph, inputMap)
+        val tfOutput = tfGraphRunner.run(inputMap)
+
+        /**
+         * TODO: UnsortedSegmentSum. The solution is almost there; we need to
+         * figure out how to output the correct shape.
+         *
+         * The shape in TF is 5 x 5, but our actual output seems to be 1 x 10.
+         * We need to change the output shape to work like TF does.
+         */
+        val output2 = importedGraph.outputAll(inputMap)
+        val output = graph.outputAll(inputMap)
+        //assertEquals(tfOutput.keys,outputList)
+        //assertEquals(tfOutput.keys,output2.keys)
+
+        val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
+        val skipValidation = setOf("parallel_stack/ExpandDims/dim")
+        //assertEquals(output.keys,output2.keys)
+        val notEquals = LinkedHashSet<String>()
+        val notEqualsTf = LinkedHashSet<String>()
+        val notEqualsOp = LinkedHashSet<String>()
+        names.forEach {
+            val value = output[it]
+            val value2 = output2[it]
+            val tfValue = tfOutput[it]
+            if(value!! != value2!!) {
+                val oldOps = importedGraph.ops[it]
+                val newOps = graph.ops[it]
+                val oldVar = importedGraph.variables[it]
+                val newVar = graph.variables[it]
+                notEquals.add(it)
+            }
+            if(tfValue != null && tfValue != value!!) {
+                val oldOps = importedGraph.ops[it]
+                val newOps = graph.ops[it]
+                val oldVar = importedGraph.variables[it]
+                val newVar = graph.variables[it]
+                notEqualsTf.add(it)
+            }
+            val oldOp = importedGraph.ops[it]
+            val newOp = graph.ops[it]
+            if(oldOp != newOp) {
+                notEqualsOp.add(it)
+            }
+        }
+
+        println(notEquals)
+        println(notEqualsTf)
+        println("Not equals ops $notEqualsOp")
+        println()
+        //assertEquals(output,output2)
+        //assertEquals(tfOutput,output)
+    }
+
+    @Throws(Exception::class)
+    fun labels(): List<String?> {
+        val labelsFile = File("imagenet_labellist.txt")
+        return FileUtils.readLines(labelsFile, StandardCharsets.UTF_8)
+    }
+
     @Test
     @Ignore
     fun manualTest2() {
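An aside on the UnsortedSegmentSum TODO in the new test above: in TensorFlow the output shape is [numSegments] followed by whatever data dimensions remain after the segment-ids rank, which is presumably why getting 1 x 10 where TF produces 5 x 5 points at a shape bug rather than a values bug. A minimal sketch of the rank-1 rule (plain C++; data, ids, and segment count are made up):

#include <cstdio>
#include <vector>

std::vector<float> unsortedSegmentSum(const std::vector<float>& data,
                                      const std::vector<int>& segmentIds,
                                      int numSegments) {
    std::vector<float> out(numSegments, 0.0f);  // output shape: [numSegments]
    for (size_t i = 0; i < data.size(); i++)
        out[segmentIds[i]] += data[i];          // accumulate into its segment
    return out;
}

int main() {
    // 10 values bucketed into 5 segments yield a [5]-shaped result,
    // not a flattened [1,10] one.
    std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    std::vector<int> ids    = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4};
    for (float v : unsortedSegmentSum(data, ids, 5))
        printf("%f ", v);                       // 3 7 11 15 19
    printf("\n");
    return 0;
}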

View File

@@ -9137,18 +9137,6 @@ mappings {
     ruleType: "tensor"
     inputFrameworkOpName: "CompareAndBitpack"
   }
-  rule {
-    ruleName: "valuemapping"
-    functionName: "valuemapping"
-    inputDataTypeName: "T"
-    outputDataTypeName: "dtype"
-    inputToOutput {
-      key: "dtype"
-      value: "T"
-    }
-    ruleType: "attribute"
-    inputFrameworkOpName: "CompareAndBitpack"
-  }
 }
 mappings {
   frameworkName: "tensorflow"