Fix NCHW case for fused batch norm

master
agibsonccc 2021-02-16 11:02:27 +09:00
parent e88d0fe96c
commit 8bc3172e40
9 changed files with 193 additions and 70 deletions

View File

@ -1578,20 +1578,20 @@ namespace sd {
int rank = shape::rank(_shapeInfo);
int lim = shape::shapeInfoLength(rank);
if(msg != nullptr)
printf("shapeInfo %s: [", msg);
else
printf("shapeInfo: [");
printf("%i, ", rank);
if(msg != nullptr) {
nd4j_printf("shapeInfo %s: [", msg);
} else {
nd4j_printf("shapeInfo: [%s", "");
}
nd4j_printf("%i, ", rank);
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
if(i == rank + 1)
printf(" ");
printf("%lld,", _shapeInfo[i]);
nd4j_printf(" ","");
nd4j_printf("%lld,", _shapeInfo[i]);
}
printf(" %lld,", shape::type(_shapeInfo));
printf("%lld,", shape::elementWiseStride(_shapeInfo));
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
nd4j_printf(" %lld,", shape::type(_shapeInfo));
nd4j_printf("%lld,", shape::elementWiseStride(_shapeInfo));
nd4j_printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
fflush(stdout);
}

View File

@ -45,6 +45,7 @@ namespace sd {
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
const bool isTraining = (bool)INT_ARG(1);
nd4j_debug("CUSTOM_OP fused_batch_norm: data format, is NCHW: %d, isTraining: %d\n",dataFormat,isTraining);
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
@ -62,8 +63,19 @@ namespace sd {
}
auto xCast = x->cast(sd::DataType::FLOAT32);
//move to NHWC
/**
* TODO: TF has a permute to NHWC here:
* https://github.com/tensorflow/tensorflow/blob/ce34a83e03394492b1c4e5bb92fbd56da2ba7ce5/tensorflow/core/kernels/fused_batch_norm_op.cc#L137
*
* This should be done as well for us, but results are still off.
* Figure out differences.
*/
if(dataFormat) {
xCast.printShapeInfo("x cast shape info pre permute");
xCast = xCast.permute({0, 2, 3, 1});
xCast.printShapeInfo("x cast shape info post permute");
}
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
@ -126,9 +138,17 @@ namespace sd {
auto scaledVariance = ((*variance + epsilon).transform(transform::RSqrt) * (*scale)).cast(xAffected.dataType());
auto xScaled1 = xCentered * scaledVariance;
auto xShifted1 = xScaled1 + *offset;
if(dataFormat) {
//need to reshape from matrix to 4d then permute the ordering due to NHWC ordering
auto reshaped = xShifted1.reshape(xCast.ordering(),xCast.getShapeAsVector());
reshaped.permutei({0,3,1,2});
y->assign(reshaped);
}
else //NHWC case
y->assign(xShifted1);
if(isTraining) {
delete mean;
delete variance;

View File

@ -88,6 +88,7 @@ import static org.nd4j.imports.tfgraphs.TFGraphsSkipNodes.skipNode;
@Slf4j
public class TFGraphTestAllHelper {
public static final String resourceFolderVar = "DL4J_TEST_RESOURCES";
public static TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter();
public enum ExecuteWith {
SAMEDIFF, LIBND4J, JUST_PRINT
@ -103,7 +104,6 @@ public class TFGraphTestAllHelper {
e.printStackTrace();
}
TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter();
return tensorflowFrameworkImporter.runImport(file.getAbsolutePath(),Collections.emptyMap());
}
}

View File

@ -76,30 +76,15 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"layers_dropout/rank3_d05_train_mask1",
"layers_dropout/rank2_d09_train",
"layers_dropout/rank2_d05_train",*/
/* "primitive_gru_dynamic",
"layers_dropout/rank4_d05_train",
"fused_batch_norm/float16_nhwc",
"rnn/lstmblockcell/dynamic_b1_n5-3_ts4_noPH_noClip_fB1_noIS_withTM",
"rnn/lstmcell/dynamic_b1_nIn5_nOut3_ts4_noPH_noClip_fB1_Tanh_noIS_float_withTM",
"rnn/grublockcellv2/dynamic_b1_n3-2_ts1_noIS_noTM"*/
/* "unsorted_segment/unsorted_segment_mean_rank3",
"unsorted_segment/unsorted_segment_sqrt_n_rank2",
"unsorted_segment/unsorted_segment_mean_rank2",
"unsorted_segment/unsorted_segment_mean_rank3",
"unsorted_segment/unsorted_segment_sum_rank3",
"unsorted_segment/unsorted_segment_min_rank2",
"unsorted_segment/unsorted_segment_prod_rank2",
"unsorted_segment/unsorted_segment_max_rank2",*/
"bincount/rank0_weights",
"bincount/rank2_weights"
/* "compare_and_bitpack/bool",
"compare_and_bitpack/bool",
"compare_and_bitpack/float32",
"compare_and_bitpack/float64",
"compare_and_bitpack/half",
"compare_and_bitpack/int32",
"compare_and_bitpack/int8",
"compare_and_bitpack/int64",
"compare_and_bitpack/int16"*/
"compare_and_bitpack/int16"

View File

@ -64,6 +64,10 @@ public class ValidateZooModelPredictions extends BaseNd4jTest {
Nd4j.getRandom().setSeed(123);
}
@Override
public long getTimeoutMilliseconds() {
return Long.MAX_VALUE;
}
@Test
public void testMobilenetV1() throws Exception {

View File

@ -489,7 +489,6 @@ val compareAndBitPack = TensorflowMappingProcess(
opName = "compare_and_bitpack",
opMappingRegistry = tensorflowOpRegistry,
inputFrameworkOpName = "CompareAndBitpack",
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))),
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input","y" to "threshold")))
)

View File

@ -9137,18 +9137,6 @@ mappings {
ruleType: "tensor"
inputFrameworkOpName: "CompareAndBitpack"
}
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "CompareAndBitpack"
}
}
mappings {
frameworkName: "tensorflow"

View File

@ -24,38 +24,32 @@ import junit.framework.Assert.assertEquals
import junit.framework.Assert.assertTrue
import org.apache.commons.io.FileUtils
import org.apache.commons.io.IOUtils
import org.junit.Assert
import org.junit.Ignore
import org.junit.jupiter.api.Test
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.common.io.ClassPathResource
import org.nd4j.imports.graphmapper.tf.TFGraphMapper
import org.nd4j.ir.OpNamespace
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.DynamicCustomOp
import org.nd4j.linalg.api.ops.custom.Roll
import org.nd4j.linalg.api.ops.impl.transforms.BinCount
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.profiler.ProfilerConfig
import org.nd4j.samediff.frameworkimport.ImportGraph
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder
import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry
import org.nd4j.samediff.frameworkimport.registry.OpRegistryHolder
import org.nd4j.samediff.frameworkimport.tensorflow.context.TensorflowMappingContext
import org.nd4j.samediff.frameworkimport.tensorflow.definitions.registry
import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRTensor
import org.nd4j.shade.protobuf.ByteString
import org.nd4j.shade.protobuf.TextFormat
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner
import org.tensorflow.framework.*
import java.io.File
import java.lang.IllegalStateException
import java.nio.charset.Charset
import kotlin.math.max
import java.nio.charset.StandardCharsets
import java.util.*
import kotlin.collections.HashMap
import kotlin.collections.HashSet
data class GraphInput(val graphDef: GraphDef,val inputNames: List<String>,val outputNames: List<String>,
val inputArrays: Map<String,INDArray>,val dynamicArrays: Map<String,INDArray>)
@ -78,6 +72,7 @@ class TestTensorflowIR {
fun manualTest() {
val manualGraph = FileUtils.readFileToString(File("test.pbtxt"),Charset.defaultCharset())
val parsedGraph = GraphDef.newBuilder()
//C:\Users\agibs\.nd4jtests\resnetv2_imagenet_frozen_graph
TextFormat.merge(manualGraph,parsedGraph)
val textGraph = parsedGraph.build()
println(textGraph)
@ -155,6 +150,150 @@ class TestTensorflowIR {
//assertEquals(tfOutput,output)
}
@Test
@Ignore
fun manualTestBinary() {
val path = "C:\\Users\\agibs\\.nd4jtests\\resnetv2_imagenet_frozen_graph\\resnetv2_imagenet_frozen_graph.pb"
val bytes = FileUtils.readFileToByteArray(File(path))
val parsedGraph = GraphDef.parseFrom(bytes)
val tfImporter = TensorflowFrameworkImporter()
//with names [image] and shapes {image=[4, 2, 28, 28, 3]}
Nd4j.getEnvironment().isDebug = true
Nd4j.getEnvironment().isVerbose = true
//TFGraphMapper.importGraph(textGraph)
// val inputMap = mapOf("input_1" to Nd4j.zeros(10).castTo(org.nd4j.linalg.api.buffer.DataType.INT32),"input_2" to Nd4j.zeros(1,8).castTo(org.nd4j.linalg.api.buffer.DataType.DOUBLE))
//val inputMap = mapOf("image" to Nd4j.ones(1,128,128,4))
/**
* TODO: fix emptyReduce.
* When we pass in 2 inputs where input 1 is the dimensions, the results
* work. In our model import, it appears that
* the empty dimensions aren't being passed down
* for int arguments properly.
* We need to figure out the difference between specifying 2 input arrays
* and ints, that or we need to make it so that arrays can be passed in
* for dimensions for each singular reduce op.
*
* Each op seems to be able to take in dimensions for indices.
* It *MIGHT* be better just to pass in dimensions directly.
*/
//Load data
//Because we don't have DataVec NativeImageLoader in ND4J tests due to circular dependencies, we'll load the image previously saved...
var imgFile =
File("goldenretriever_rgb224_unnormalized_nchw_INDArray.bin")
var img = Nd4j.readBinary(imgFile).castTo(org.nd4j.linalg.api.buffer.DataType.FLOAT)
img = img.permute(0, 2, 3, 1).dup() //to NHWC
//Perform inference
//Resnet v2 - NO external normalization, just resize and center crop
// https://github.com/tensorflow/models/blob/d32d957a02f5cffb745a4da0d78f8432e2c52fd4/research/tensorrt/tensorrt.py#L70
// https://github.com/tensorflow/models/blob/1af55e018eebce03fb61bba9959a04672536107d/official/resnet/imagenet_preprocessing.py#L253-L256
val importedGraph = TFGraphMapper.importGraph(parsedGraph)
//Load labels
val labels = labels()
//Perform inference
val inputs: List<String> = importedGraph.inputs()
Assert.assertEquals(1, inputs.size.toLong())
val out = "softmax_tensor"
val m: Map<String, INDArray> = importedGraph.output(Collections.singletonMap(inputs[0], img), out)
val outArr = m[out]
println("SHAPE: " + Arrays.toString(outArr!!.shape()))
println(outArr)
val argmax = outArr!!.argMax(1)
//Load labels
val classIdx = argmax.getInt(0)
val className = labels[classIdx]
val expClass = "golden retriever"
val prob = outArr!!.getDouble(classIdx.toLong())
println("Predicted class: $classIdx - \"$className\" - probability = $prob")
Assert.assertEquals(expClass, className)
val inputMap = Collections.singletonMap(inputs[0], img)
val tensorflowIRGraph = TensorflowIRGraph(parsedGraph,tensorflowOps,tfImporter.registry)
val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toMutableSet()
val tfGraphRunner = TensorflowIRGraphRunner(tensorflowIRGraph, inputMap.keys.toList(),listOf("batch_normalization/FusedBatchNorm",out))
val graph = tfImporter.importFromGraph(parsedGraph,inputMap)
val tfOutput = tfGraphRunner.run(inputMap)
/**
* TODO: UnsortedSegmentSum ,Solution is almost there, need to figure out how to
* output correct shape.
*
* Shape in TF is 5 x 5 but actual real output seems to be 1 x 10.
* We need to change the output shape to work like TF does.
*/
val output2 = importedGraph.outputAll(inputMap)
val output = graph.outputAll(inputMap)
//assertEquals(tfOutput.keys,outputList)
//assertEquals(tfOutput.keys,output2.keys)
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
val skipValidation = setOf("parallel_stack/ExpandDims/dim")
//assertEquals(output.keys,output2.keys)
val notEquals = LinkedHashSet<String>()
val notEqualsTf = LinkedHashSet<String>()
val notEqualsOp = LinkedHashSet<String>()
names.forEach {
val value = output[it]
val value2 = output2[it]
val tfValue = tfOutput[it]
if(value!! != (value2!!)) {
val oldOps = importedGraph.ops[it]
val newOps = graph.ops[it]
val oldVar = importedGraph.variables[it]
val newVar = graph.variables[it]
notEquals.add(it)
}
if(tfValue != null && tfValue!! != (value!!)) {
val oldOps = importedGraph.ops[it]
val newOps = graph.ops[it]
val oldVar = importedGraph.variables[it]
val newVar = graph.variables[it]
notEqualsTf.add(it)
}
val oldOp = importedGraph.ops[it]
val newOp = graph.ops[it]
if(oldOp != newOp) {
notEqualsOp.add(it)
}
}
println(notEquals)
println(notEqualsTf)
println("Not equals ops $notEqualsOp")
println()
// assertEquals(output,output2)
//assertEquals(tfOutput,output)
}
@Throws(Exception::class)
fun labels(): List<String?> {
val labelsFile =
File("imagenet_labellist.txt")
return FileUtils.readLines(labelsFile, StandardCharsets.UTF_8)
}
@Test
@Ignore
fun manualTest2() {

View File

@ -9137,18 +9137,6 @@ mappings {
ruleType: "tensor"
inputFrameworkOpName: "CompareAndBitpack"
}
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "CompareAndBitpack"
}
}
mappings {
frameworkName: "tensorflow"