Fix NCHW case for fused batch norm
parent
e88d0fe96c
commit
8bc3172e40
|
@ -1578,20 +1578,20 @@ namespace sd {
|
|||
int rank = shape::rank(_shapeInfo);
|
||||
int lim = shape::shapeInfoLength(rank);
|
||||
|
||||
if(msg != nullptr)
|
||||
printf("shapeInfo %s: [", msg);
|
||||
else
|
||||
printf("shapeInfo: [");
|
||||
|
||||
printf("%i, ", rank);
|
||||
if(msg != nullptr) {
|
||||
nd4j_printf("shapeInfo %s: [", msg);
|
||||
} else {
|
||||
nd4j_printf("shapeInfo: [%s", "");
|
||||
}
|
||||
nd4j_printf("%i, ", rank);
|
||||
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
|
||||
if(i == rank + 1)
|
||||
printf(" ");
|
||||
printf("%lld,", _shapeInfo[i]);
|
||||
nd4j_printf(" ","");
|
||||
nd4j_printf("%lld,", _shapeInfo[i]);
|
||||
}
|
||||
printf(" %lld,", shape::type(_shapeInfo));
|
||||
printf("%lld,", shape::elementWiseStride(_shapeInfo));
|
||||
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
|
||||
nd4j_printf(" %lld,", shape::type(_shapeInfo));
|
||||
nd4j_printf("%lld,", shape::elementWiseStride(_shapeInfo));
|
||||
nd4j_printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
|
|
|
@ -45,6 +45,7 @@ namespace sd {
|
|||
|
||||
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
||||
const bool isTraining = (bool)INT_ARG(1);
|
||||
nd4j_debug("CUSTOM_OP fused_batch_norm: data format, is NCHW: %d, isTraining: %d\n",dataFormat,isTraining);
|
||||
|
||||
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
|
||||
|
||||
|
@ -62,8 +63,19 @@ namespace sd {
|
|||
}
|
||||
|
||||
auto xCast = x->cast(sd::DataType::FLOAT32);
|
||||
|
||||
|
||||
//move to NHWC
|
||||
/**
|
||||
* TODO: TF has a permute to NHWC here:
|
||||
* https://github.com/tensorflow/tensorflow/blob/ce34a83e03394492b1c4e5bb92fbd56da2ba7ce5/tensorflow/core/kernels/fused_batch_norm_op.cc#L137
|
||||
*
|
||||
* This should be done as well for us, but results are still off.
|
||||
* Figure out differences.
|
||||
*/
|
||||
if(dataFormat) {
|
||||
xCast.printShapeInfo("x cast shape info pre permute");
|
||||
xCast = xCast.permute({0, 2, 3, 1});
|
||||
xCast.printShapeInfo("x cast shape info post permute");
|
||||
}
|
||||
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
|
||||
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
|
||||
|
||||
|
@ -126,14 +138,22 @@ namespace sd {
|
|||
auto scaledVariance = ((*variance + epsilon).transform(transform::RSqrt) * (*scale)).cast(xAffected.dataType());
|
||||
auto xScaled1 = xCentered * scaledVariance;
|
||||
auto xShifted1 = xScaled1 + *offset;
|
||||
if(dataFormat) {
|
||||
//need to reshape from matrix to 4d then permute the ordering due to NHWC ordering
|
||||
auto reshaped = xShifted1.reshape(xCast.ordering(),xCast.getShapeAsVector());
|
||||
reshaped.permutei({0,3,1,2});
|
||||
y->assign(reshaped);
|
||||
|
||||
}
|
||||
else //NHWC case
|
||||
y->assign(xShifted1);
|
||||
|
||||
y->assign(xShifted1);
|
||||
|
||||
if(isTraining) {
|
||||
delete mean;
|
||||
delete variance;
|
||||
}
|
||||
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -88,6 +88,7 @@ import static org.nd4j.imports.tfgraphs.TFGraphsSkipNodes.skipNode;
|
|||
@Slf4j
|
||||
public class TFGraphTestAllHelper {
|
||||
public static final String resourceFolderVar = "DL4J_TEST_RESOURCES";
|
||||
public static TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter();
|
||||
|
||||
public enum ExecuteWith {
|
||||
SAMEDIFF, LIBND4J, JUST_PRINT
|
||||
|
@ -103,7 +104,6 @@ public class TFGraphTestAllHelper {
|
|||
e.printStackTrace();
|
||||
}
|
||||
|
||||
TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter();
|
||||
return tensorflowFrameworkImporter.runImport(file.getAbsolutePath(),Collections.emptyMap());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,30 +76,15 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
"layers_dropout/rank3_d05_train_mask1",
|
||||
"layers_dropout/rank2_d09_train",
|
||||
"layers_dropout/rank2_d05_train",*/
|
||||
/* "primitive_gru_dynamic",
|
||||
"layers_dropout/rank4_d05_train",
|
||||
"fused_batch_norm/float16_nhwc",
|
||||
"rnn/lstmblockcell/dynamic_b1_n5-3_ts4_noPH_noClip_fB1_noIS_withTM",
|
||||
"rnn/lstmcell/dynamic_b1_nIn5_nOut3_ts4_noPH_noClip_fB1_Tanh_noIS_float_withTM",
|
||||
"rnn/grublockcellv2/dynamic_b1_n3-2_ts1_noIS_noTM"*/
|
||||
/* "unsorted_segment/unsorted_segment_mean_rank3",
|
||||
"unsorted_segment/unsorted_segment_sqrt_n_rank2",
|
||||
"unsorted_segment/unsorted_segment_mean_rank2",
|
||||
"unsorted_segment/unsorted_segment_mean_rank3",
|
||||
"unsorted_segment/unsorted_segment_sum_rank3",
|
||||
"unsorted_segment/unsorted_segment_min_rank2",
|
||||
"unsorted_segment/unsorted_segment_prod_rank2",
|
||||
"unsorted_segment/unsorted_segment_max_rank2",*/
|
||||
"bincount/rank0_weights",
|
||||
"bincount/rank2_weights"
|
||||
/* "compare_and_bitpack/bool",
|
||||
|
||||
"compare_and_bitpack/bool",
|
||||
"compare_and_bitpack/float32",
|
||||
"compare_and_bitpack/float64",
|
||||
"compare_and_bitpack/half",
|
||||
"compare_and_bitpack/int32",
|
||||
"compare_and_bitpack/int8",
|
||||
"compare_and_bitpack/int64",
|
||||
"compare_and_bitpack/int16"*/
|
||||
"compare_and_bitpack/int16"
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -64,6 +64,10 @@ public class ValidateZooModelPredictions extends BaseNd4jTest {
|
|||
Nd4j.getRandom().setSeed(123);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return Long.MAX_VALUE;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobilenetV1() throws Exception {
|
||||
|
|
|
@ -489,7 +489,6 @@ val compareAndBitPack = TensorflowMappingProcess(
|
|||
opName = "compare_and_bitpack",
|
||||
opMappingRegistry = tensorflowOpRegistry,
|
||||
inputFrameworkOpName = "CompareAndBitpack",
|
||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))),
|
||||
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input","y" to "threshold")))
|
||||
)
|
||||
|
||||
|
|
|
@ -9137,18 +9137,6 @@ mappings {
|
|||
ruleType: "tensor"
|
||||
inputFrameworkOpName: "CompareAndBitpack"
|
||||
}
|
||||
rule {
|
||||
ruleName: "valuemapping"
|
||||
functionName: "valuemapping"
|
||||
inputDataTypeName: "T"
|
||||
outputDataTypeName: "dtype"
|
||||
inputToOutput {
|
||||
key: "dtype"
|
||||
value: "T"
|
||||
}
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "CompareAndBitpack"
|
||||
}
|
||||
}
|
||||
mappings {
|
||||
frameworkName: "tensorflow"
|
||||
|
|
|
@ -24,38 +24,32 @@ import junit.framework.Assert.assertEquals
|
|||
import junit.framework.Assert.assertTrue
|
||||
import org.apache.commons.io.FileUtils
|
||||
import org.apache.commons.io.IOUtils
|
||||
import org.junit.Assert
|
||||
import org.junit.Ignore
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.nd4j.autodiff.samediff.SameDiff
|
||||
import org.nd4j.common.io.ClassPathResource
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper
|
||||
import org.nd4j.ir.OpNamespace
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp
|
||||
import org.nd4j.linalg.api.ops.custom.Roll
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BinCount
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import org.nd4j.linalg.profiler.ProfilerConfig
|
||||
import org.nd4j.samediff.frameworkimport.ImportGraph
|
||||
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder
|
||||
import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry
|
||||
import org.nd4j.samediff.frameworkimport.registry.OpRegistryHolder
|
||||
import org.nd4j.samediff.frameworkimport.tensorflow.context.TensorflowMappingContext
|
||||
import org.nd4j.samediff.frameworkimport.tensorflow.definitions.registry
|
||||
import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter
|
||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
|
||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
|
||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode
|
||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRTensor
|
||||
import org.nd4j.shade.protobuf.ByteString
|
||||
import org.nd4j.shade.protobuf.TextFormat
|
||||
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner
|
||||
import org.tensorflow.framework.*
|
||||
import java.io.File
|
||||
import java.lang.IllegalStateException
|
||||
import java.nio.charset.Charset
|
||||
import kotlin.math.max
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.util.*
|
||||
import kotlin.collections.HashMap
|
||||
import kotlin.collections.HashSet
|
||||
|
||||
data class GraphInput(val graphDef: GraphDef,val inputNames: List<String>,val outputNames: List<String>,
|
||||
val inputArrays: Map<String,INDArray>,val dynamicArrays: Map<String,INDArray>)
|
||||
|
@ -78,6 +72,7 @@ class TestTensorflowIR {
|
|||
fun manualTest() {
|
||||
val manualGraph = FileUtils.readFileToString(File("test.pbtxt"),Charset.defaultCharset())
|
||||
val parsedGraph = GraphDef.newBuilder()
|
||||
//C:\Users\agibs\.nd4jtests\resnetv2_imagenet_frozen_graph
|
||||
TextFormat.merge(manualGraph,parsedGraph)
|
||||
val textGraph = parsedGraph.build()
|
||||
println(textGraph)
|
||||
|
@ -155,6 +150,150 @@ class TestTensorflowIR {
|
|||
//assertEquals(tfOutput,output)
|
||||
}
|
||||
|
||||
|
||||
/**
 * Manual end-to-end comparison of the two TensorFlow import paths on a frozen ResNet v2 graph.
 *
 * Steps:
 * 1. Parses the binary `GraphDef` from a hard-coded local path.
 * 2. Loads a previously saved image array, casts it to FLOAT and permutes it with (0, 2, 3, 1).
 * 3. Imports the graph with [TFGraphMapper], runs `softmax_tensor` and asserts that the arg-max
 *    label is "golden retriever".
 * 4. Imports the same graph with [TensorflowFrameworkImporter], runs TensorFlow itself through
 *    [TensorflowIRGraphRunner], and prints the node names whose values or ops differ.
 *
 * Only the predicted label is asserted; the per-node comparison is printed for inspection.
 *
 * The test depends on files that exist only on the author's machine, so it must stay disabled.
 * It is disabled with Jupiter's `@Disabled` because this method is a Jupiter `@Test`, and
 * JUnit 4's `@Ignore` is only honored by Jupiter when JUnit 4 migration support is registered.
 */
@Test
@org.junit.jupiter.api.Disabled("Manual test: requires a locally downloaded ResNet v2 frozen graph, image and label files")
fun manualTestBinary() {
    val path = "C:\\Users\\agibs\\.nd4jtests\\resnetv2_imagenet_frozen_graph\\resnetv2_imagenet_frozen_graph.pb"
    val parsedGraph = GraphDef.parseFrom(FileUtils.readFileToByteArray(File(path)))
    val tfImporter = TensorflowFrameworkImporter()
    //with names [image] and shapes {image=[4, 2, 28, 28, 3]}
    Nd4j.getEnvironment().isDebug = true
    Nd4j.getEnvironment().isVerbose = true
    //TFGraphMapper.importGraph(textGraph)
    // val inputMap = mapOf("input_1" to Nd4j.zeros(10).castTo(org.nd4j.linalg.api.buffer.DataType.INT32),"input_2" to Nd4j.zeros(1,8).castTo(org.nd4j.linalg.api.buffer.DataType.DOUBLE))
    //val inputMap = mapOf("image" to Nd4j.ones(1,128,128,4))
    /**
     * TODO: fix emptyReduce.
     * When we pass in 2 inputs where input 1 is the dimensions, the results
     * work. In our model import, it appears that
     * the empty dimensions aren't being passed down
     * for int arguments properly.
     * We need to figure out the difference between specifying 2 input arrays
     * and ints, that or we need to make it so that arrays can be passed in
     * for dimensions for each singular reduce op.
     *
     * Each op seems to be able to take in dimensions for indices.
     * It *MIGHT* be better just to pass in dimensions directly.
     */

    //Load data
    //Because we don't have DataVec NativeImageLoader in ND4J tests due to circular dependencies, we'll load the image previously saved...
    val img = Nd4j.readBinary(File("goldenretriever_rgb224_unnormalized_nchw_INDArray.bin"))
        .castTo(org.nd4j.linalg.api.buffer.DataType.FLOAT)
        .permute(0, 2, 3, 1)
        .dup() //to NHWC

    //Resnet v2 - NO external normalization, just resize and center crop
    // https://github.com/tensorflow/models/blob/d32d957a02f5cffb745a4da0d78f8432e2c52fd4/research/tensorrt/tensorrt.py#L70
    // https://github.com/tensorflow/models/blob/1af55e018eebce03fb61bba9959a04672536107d/official/resnet/imagenet_preprocessing.py#L253-L256
    val importedGraph = TFGraphMapper.importGraph(parsedGraph)

    //Load labels
    val labels = labels()

    //Perform inference
    val inputs: List<String> = importedGraph.inputs()
    Assert.assertEquals(1, inputs.size.toLong())

    val out = "softmax_tensor"
    val inputMap = Collections.singletonMap(inputs[0], img)
    val m: Map<String, INDArray> = importedGraph.output(inputMap, out)
    val outArr = checkNotNull(m[out]) { "Imported graph produced no output named $out" }

    println("SHAPE: " + Arrays.toString(outArr.shape()))
    println(outArr)

    val classIdx = outArr.argMax(1).getInt(0)
    val className = labels[classIdx]
    val expClass = "golden retriever"
    val prob = outArr.getDouble(classIdx.toLong())

    println("Predicted class: $classIdx - \"$className\" - probability = $prob")
    Assert.assertEquals(expClass, className)

    //Compare the legacy import, the new import and TensorFlow itself node by node
    val tensorflowIRGraph = TensorflowIRGraph(parsedGraph,tensorflowOps,tfImporter.registry)
    val tfGraphRunner = TensorflowIRGraphRunner(tensorflowIRGraph, inputMap.keys.toList(),listOf("batch_normalization/FusedBatchNorm",out))
    val graph = tfImporter.importFromGraph(parsedGraph,inputMap)
    val tfOutput = tfGraphRunner.run(inputMap)

    /**
     * TODO: UnsortedSegmentSum ,Solution is almost there, need to figure out how to
     * output correct shape.
     *
     * Shape in TF is 5 x 5 but actual real output seems to be 1 x 10.
     * We need to change the output shape to work like TF does.
     */
    val output2 = importedGraph.outputAll(inputMap)
    val output = graph.outputAll(inputMap)

    //assertEquals(tfOutput.keys,output2.keys)
    //assertEquals(output.keys,output2.keys)
    val names = tensorflowIRGraph.nodeList().map { node -> node.nodeName() }
    val notEquals = LinkedHashSet<String>()
    val notEqualsTf = LinkedHashSet<String>()
    val notEqualsOp = LinkedHashSet<String>()
    names.forEach { name ->
        // Fail with the offending node name instead of a bare NullPointerException.
        val newValue = checkNotNull(output[name]) { "New import produced no value for node $name" }
        val oldValue = checkNotNull(output2[name]) { "Legacy import produced no value for node $name" }
        val tfValue = tfOutput[name]

        if (newValue != oldValue) {
            notEquals.add(name)
        }

        // TensorFlow was only asked for a subset of outputs, so a missing value is not a mismatch.
        if (tfValue != null && tfValue != newValue) {
            notEqualsTf.add(name)
        }

        if (importedGraph.ops[name] != graph.ops[name]) {
            notEqualsOp.add(name)
        }
    }

    println(notEquals)
    println(notEqualsTf)
    println("Not equals ops $notEqualsOp")
    println()
    // assertEquals(output,output2)
    //assertEquals(tfOutput,output)
}
|
||||
|
||||
/**
 * Reads the label list used to turn a predicted class index into a class name.
 *
 * The file `imagenet_labellist.txt` is resolved relative to the current working directory
 * and decoded as UTF-8.
 *
 * @return the lines of the label file, in file order
 * @throws Exception if the file cannot be read
 */
@Throws(Exception::class)
fun labels(): List<String?> =
    FileUtils.readLines(File("imagenet_labellist.txt"), StandardCharsets.UTF_8)
|
||||
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
fun manualTest2() {
|
||||
|
|
|
@ -9137,18 +9137,6 @@ mappings {
|
|||
ruleType: "tensor"
|
||||
inputFrameworkOpName: "CompareAndBitpack"
|
||||
}
|
||||
rule {
|
||||
ruleName: "valuemapping"
|
||||
functionName: "valuemapping"
|
||||
inputDataTypeName: "T"
|
||||
outputDataTypeName: "dtype"
|
||||
inputToOutput {
|
||||
key: "dtype"
|
||||
value: "T"
|
||||
}
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "CompareAndBitpack"
|
||||
}
|
||||
}
|
||||
mappings {
|
||||
frameworkName: "tensorflow"
|
||||
|
|
Loading…
Reference in New Issue