Fix NCHW case for fused batch norm
parent
e88d0fe96c
commit
8bc3172e40
|
@ -1578,20 +1578,20 @@ namespace sd {
|
||||||
int rank = shape::rank(_shapeInfo);
|
int rank = shape::rank(_shapeInfo);
|
||||||
int lim = shape::shapeInfoLength(rank);
|
int lim = shape::shapeInfoLength(rank);
|
||||||
|
|
||||||
if(msg != nullptr)
|
if(msg != nullptr) {
|
||||||
printf("shapeInfo %s: [", msg);
|
nd4j_printf("shapeInfo %s: [", msg);
|
||||||
else
|
} else {
|
||||||
printf("shapeInfo: [");
|
nd4j_printf("shapeInfo: [%s", "");
|
||||||
|
}
|
||||||
printf("%i, ", rank);
|
nd4j_printf("%i, ", rank);
|
||||||
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
|
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
|
||||||
if(i == rank + 1)
|
if(i == rank + 1)
|
||||||
printf(" ");
|
nd4j_printf(" ","");
|
||||||
printf("%lld,", _shapeInfo[i]);
|
nd4j_printf("%lld,", _shapeInfo[i]);
|
||||||
}
|
}
|
||||||
printf(" %lld,", shape::type(_shapeInfo));
|
nd4j_printf(" %lld,", shape::type(_shapeInfo));
|
||||||
printf("%lld,", shape::elementWiseStride(_shapeInfo));
|
nd4j_printf("%lld,", shape::elementWiseStride(_shapeInfo));
|
||||||
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
|
nd4j_printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
|
||||||
|
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,6 +45,7 @@ namespace sd {
|
||||||
|
|
||||||
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
||||||
const bool isTraining = (bool)INT_ARG(1);
|
const bool isTraining = (bool)INT_ARG(1);
|
||||||
|
nd4j_debug("CUSTOM_OP fused_batch_norm: data format, is NCHW: %d, isTraining: %d\n",dataFormat,isTraining);
|
||||||
|
|
||||||
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
|
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
|
||||||
|
|
||||||
|
@ -62,8 +63,19 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto xCast = x->cast(sd::DataType::FLOAT32);
|
auto xCast = x->cast(sd::DataType::FLOAT32);
|
||||||
|
//move to NWHC
|
||||||
|
/**
|
||||||
|
* TODO: TF has a permute to NWHC here:
|
||||||
|
* https://github.com/tensorflow/tensorflow/blob/ce34a83e03394492b1c4e5bb92fbd56da2ba7ce5/tensorflow/core/kernels/fused_batch_norm_op.cc#L137
|
||||||
|
*
|
||||||
|
* This should be done as well for us, but results are still off.
|
||||||
|
* Figure out differences.
|
||||||
|
*/
|
||||||
|
if(dataFormat) {
|
||||||
|
xCast.printShapeInfo("x cast shape info pre permute");
|
||||||
|
xCast = xCast.permute({0, 2, 3, 1});
|
||||||
|
xCast.printShapeInfo("x cast shape info post permute");
|
||||||
|
}
|
||||||
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
|
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
|
||||||
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
|
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
|
||||||
|
|
||||||
|
@ -126,14 +138,22 @@ namespace sd {
|
||||||
auto scaledVariance = ((*variance + epsilon).transform(transform::RSqrt) * (*scale)).cast(xAffected.dataType());
|
auto scaledVariance = ((*variance + epsilon).transform(transform::RSqrt) * (*scale)).cast(xAffected.dataType());
|
||||||
auto xScaled1 = xCentered * scaledVariance;
|
auto xScaled1 = xCentered * scaledVariance;
|
||||||
auto xShifted1 = xScaled1 + *offset;
|
auto xShifted1 = xScaled1 + *offset;
|
||||||
|
if(dataFormat) {
|
||||||
|
//need to reshape from matrix to 4d then permute the ordering due to NWHC ordering
|
||||||
|
auto reshaped = xShifted1.reshape(xCast.ordering(),xCast.getShapeAsVector());
|
||||||
|
reshaped.permutei({0,3,1,2});
|
||||||
|
y->assign(reshaped);
|
||||||
|
|
||||||
|
}
|
||||||
|
else //NWHC case
|
||||||
|
y->assign(xShifted1);
|
||||||
|
|
||||||
y->assign(xShifted1);
|
|
||||||
|
|
||||||
if(isTraining) {
|
if(isTraining) {
|
||||||
delete mean;
|
delete mean;
|
||||||
delete variance;
|
delete variance;
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -88,6 +88,7 @@ import static org.nd4j.imports.tfgraphs.TFGraphsSkipNodes.skipNode;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TFGraphTestAllHelper {
|
public class TFGraphTestAllHelper {
|
||||||
public static final String resourceFolderVar = "DL4J_TEST_RESOURCES";
|
public static final String resourceFolderVar = "DL4J_TEST_RESOURCES";
|
||||||
|
public static TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter();
|
||||||
|
|
||||||
public enum ExecuteWith {
|
public enum ExecuteWith {
|
||||||
SAMEDIFF, LIBND4J, JUST_PRINT
|
SAMEDIFF, LIBND4J, JUST_PRINT
|
||||||
|
@ -103,7 +104,6 @@ public class TFGraphTestAllHelper {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorflowFrameworkImporter tensorflowFrameworkImporter = new TensorflowFrameworkImporter();
|
|
||||||
return tensorflowFrameworkImporter.runImport(file.getAbsolutePath(),Collections.emptyMap());
|
return tensorflowFrameworkImporter.runImport(file.getAbsolutePath(),Collections.emptyMap());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,30 +76,15 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
"layers_dropout/rank3_d05_train_mask1",
|
"layers_dropout/rank3_d05_train_mask1",
|
||||||
"layers_dropout/rank2_d09_train",
|
"layers_dropout/rank2_d09_train",
|
||||||
"layers_dropout/rank2_d05_train",*/
|
"layers_dropout/rank2_d05_train",*/
|
||||||
/* "primitive_gru_dynamic",
|
|
||||||
"layers_dropout/rank4_d05_train",
|
"compare_and_bitpack/bool",
|
||||||
"fused_batch_norm/float16_nhwc",
|
|
||||||
"rnn/lstmblockcell/dynamic_b1_n5-3_ts4_noPH_noClip_fB1_noIS_withTM",
|
|
||||||
"rnn/lstmcell/dynamic_b1_nIn5_nOut3_ts4_noPH_noClip_fB1_Tanh_noIS_float_withTM",
|
|
||||||
"rnn/grublockcellv2/dynamic_b1_n3-2_ts1_noIS_noTM"*/
|
|
||||||
/* "unsorted_segment/unsorted_segment_mean_rank3",
|
|
||||||
"unsorted_segment/unsorted_segment_sqrt_n_rank2",
|
|
||||||
"unsorted_segment/unsorted_segment_mean_rank2",
|
|
||||||
"unsorted_segment/unsorted_segment_mean_rank3",
|
|
||||||
"unsorted_segment/unsorted_segment_sum_rank3",
|
|
||||||
"unsorted_segment/unsorted_segment_min_rank2",
|
|
||||||
"unsorted_segment/unsorted_segment_prod_rank2",
|
|
||||||
"unsorted_segment/unsorted_segment_max_rank2",*/
|
|
||||||
"bincount/rank0_weights",
|
|
||||||
"bincount/rank2_weights"
|
|
||||||
/* "compare_and_bitpack/bool",
|
|
||||||
"compare_and_bitpack/float32",
|
"compare_and_bitpack/float32",
|
||||||
"compare_and_bitpack/float64",
|
"compare_and_bitpack/float64",
|
||||||
"compare_and_bitpack/half",
|
"compare_and_bitpack/half",
|
||||||
"compare_and_bitpack/int32",
|
"compare_and_bitpack/int32",
|
||||||
"compare_and_bitpack/int8",
|
"compare_and_bitpack/int8",
|
||||||
"compare_and_bitpack/int64",
|
"compare_and_bitpack/int64",
|
||||||
"compare_and_bitpack/int16"*/
|
"compare_and_bitpack/int16"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,10 @@ public class ValidateZooModelPredictions extends BaseNd4jTest {
|
||||||
Nd4j.getRandom().setSeed(123);
|
Nd4j.getRandom().setSeed(123);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return Long.MAX_VALUE;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMobilenetV1() throws Exception {
|
public void testMobilenetV1() throws Exception {
|
||||||
|
|
|
@ -489,7 +489,6 @@ val compareAndBitPack = TensorflowMappingProcess(
|
||||||
opName = "compare_and_bitpack",
|
opName = "compare_and_bitpack",
|
||||||
opMappingRegistry = tensorflowOpRegistry,
|
opMappingRegistry = tensorflowOpRegistry,
|
||||||
inputFrameworkOpName = "CompareAndBitpack",
|
inputFrameworkOpName = "CompareAndBitpack",
|
||||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))),
|
|
||||||
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input","y" to "threshold")))
|
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input","y" to "threshold")))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -9137,18 +9137,6 @@ mappings {
|
||||||
ruleType: "tensor"
|
ruleType: "tensor"
|
||||||
inputFrameworkOpName: "CompareAndBitpack"
|
inputFrameworkOpName: "CompareAndBitpack"
|
||||||
}
|
}
|
||||||
rule {
|
|
||||||
ruleName: "valuemapping"
|
|
||||||
functionName: "valuemapping"
|
|
||||||
inputDataTypeName: "T"
|
|
||||||
outputDataTypeName: "dtype"
|
|
||||||
inputToOutput {
|
|
||||||
key: "dtype"
|
|
||||||
value: "T"
|
|
||||||
}
|
|
||||||
ruleType: "attribute"
|
|
||||||
inputFrameworkOpName: "CompareAndBitpack"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
|
|
|
@ -24,38 +24,32 @@ import junit.framework.Assert.assertEquals
|
||||||
import junit.framework.Assert.assertTrue
|
import junit.framework.Assert.assertTrue
|
||||||
import org.apache.commons.io.FileUtils
|
import org.apache.commons.io.FileUtils
|
||||||
import org.apache.commons.io.IOUtils
|
import org.apache.commons.io.IOUtils
|
||||||
|
import org.junit.Assert
|
||||||
import org.junit.Ignore
|
import org.junit.Ignore
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import org.nd4j.autodiff.samediff.SameDiff
|
|
||||||
import org.nd4j.common.io.ClassPathResource
|
import org.nd4j.common.io.ClassPathResource
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper
|
||||||
import org.nd4j.ir.OpNamespace
|
import org.nd4j.ir.OpNamespace
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp
|
import org.nd4j.linalg.api.ops.DynamicCustomOp
|
||||||
import org.nd4j.linalg.api.ops.custom.Roll
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.BinCount
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt
|
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import org.nd4j.linalg.profiler.ProfilerConfig
|
import org.nd4j.linalg.profiler.ProfilerConfig
|
||||||
import org.nd4j.samediff.frameworkimport.ImportGraph
|
import org.nd4j.samediff.frameworkimport.ImportGraph
|
||||||
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder
|
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder
|
||||||
import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry
|
|
||||||
import org.nd4j.samediff.frameworkimport.registry.OpRegistryHolder
|
import org.nd4j.samediff.frameworkimport.registry.OpRegistryHolder
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.context.TensorflowMappingContext
|
import org.nd4j.samediff.frameworkimport.tensorflow.context.TensorflowMappingContext
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.definitions.registry
|
import org.nd4j.samediff.frameworkimport.tensorflow.definitions.registry
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter
|
import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
|
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
|
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode
|
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRTensor
|
|
||||||
import org.nd4j.shade.protobuf.ByteString
|
|
||||||
import org.nd4j.shade.protobuf.TextFormat
|
import org.nd4j.shade.protobuf.TextFormat
|
||||||
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner
|
|
||||||
import org.tensorflow.framework.*
|
import org.tensorflow.framework.*
|
||||||
import java.io.File
|
import java.io.File
|
||||||
import java.lang.IllegalStateException
|
|
||||||
import java.nio.charset.Charset
|
import java.nio.charset.Charset
|
||||||
import kotlin.math.max
|
import java.nio.charset.StandardCharsets
|
||||||
|
import java.util.*
|
||||||
|
import kotlin.collections.HashMap
|
||||||
|
import kotlin.collections.HashSet
|
||||||
|
|
||||||
data class GraphInput(val graphDef: GraphDef,val inputNames: List<String>,val outputNames: List<String>,
|
data class GraphInput(val graphDef: GraphDef,val inputNames: List<String>,val outputNames: List<String>,
|
||||||
val inputArrays: Map<String,INDArray>,val dynamicArrays: Map<String,INDArray>)
|
val inputArrays: Map<String,INDArray>,val dynamicArrays: Map<String,INDArray>)
|
||||||
|
@ -78,6 +72,7 @@ class TestTensorflowIR {
|
||||||
fun manualTest() {
|
fun manualTest() {
|
||||||
val manualGraph = FileUtils.readFileToString(File("test.pbtxt"),Charset.defaultCharset())
|
val manualGraph = FileUtils.readFileToString(File("test.pbtxt"),Charset.defaultCharset())
|
||||||
val parsedGraph = GraphDef.newBuilder()
|
val parsedGraph = GraphDef.newBuilder()
|
||||||
|
//C:\Users\agibs\.nd4jtests\resnetv2_imagenet_frozen_graph
|
||||||
TextFormat.merge(manualGraph,parsedGraph)
|
TextFormat.merge(manualGraph,parsedGraph)
|
||||||
val textGraph = parsedGraph.build()
|
val textGraph = parsedGraph.build()
|
||||||
println(textGraph)
|
println(textGraph)
|
||||||
|
@ -155,6 +150,150 @@ class TestTensorflowIR {
|
||||||
//assertEquals(tfOutput,output)
|
//assertEquals(tfOutput,output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
fun manualTestBinary() {
|
||||||
|
val path = "C:\\Users\\agibs\\.nd4jtests\\resnetv2_imagenet_frozen_graph\\resnetv2_imagenet_frozen_graph.pb"
|
||||||
|
val bytes = FileUtils.readFileToByteArray(File(path))
|
||||||
|
val parsedGraph = GraphDef.parseFrom(bytes)
|
||||||
|
val tfImporter = TensorflowFrameworkImporter()
|
||||||
|
//with names [image] and shapes {image=[4, 2, 28, 28, 3]}
|
||||||
|
Nd4j.getEnvironment().isDebug = true
|
||||||
|
Nd4j.getEnvironment().isVerbose = true
|
||||||
|
//TFGraphMapper.importGraph(textGraph)
|
||||||
|
// val inputMap = mapOf("input_1" to Nd4j.zeros(10).castTo(org.nd4j.linalg.api.buffer.DataType.INT32),"input_2" to Nd4j.zeros(1,8).castTo(org.nd4j.linalg.api.buffer.DataType.DOUBLE))
|
||||||
|
//val inputMap = mapOf("image" to Nd4j.ones(1,128,128,4))
|
||||||
|
/**
|
||||||
|
* TODO: fix emptyReduce.
|
||||||
|
* When we pass in 2 inputs where input 1 is the dimensions, the results
|
||||||
|
* work. In our model import, it appears that
|
||||||
|
* the empty dimensions aren't being passed down
|
||||||
|
* for int arguments properly.
|
||||||
|
* We need to figure out the difference between specifying 2 input arrays
|
||||||
|
* and ints, that or we need to make it so that arrays can be passed in
|
||||||
|
* for dimensions for each singular reduce op.
|
||||||
|
*
|
||||||
|
* Each op seems to be able to take in dimensions for indices.
|
||||||
|
* It *MIGHT* be better just to pass in dimensions directly.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
//Load data
|
||||||
|
//Because we don't have DataVec NativeImageLoader in ND4J tests due to circular dependencies, we'll load the image previously saved...
|
||||||
|
var imgFile =
|
||||||
|
File("goldenretriever_rgb224_unnormalized_nchw_INDArray.bin")
|
||||||
|
var img = Nd4j.readBinary(imgFile).castTo(org.nd4j.linalg.api.buffer.DataType.FLOAT)
|
||||||
|
img = img.permute(0, 2, 3, 1).dup() //to NHWC
|
||||||
|
|
||||||
|
|
||||||
|
//Perform inference
|
||||||
|
|
||||||
|
//Resnet v2 - NO external normalization, just resize and center crop
|
||||||
|
// https://github.com/tensorflow/models/blob/d32d957a02f5cffb745a4da0d78f8432e2c52fd4/research/tensorrt/tensorrt.py#L70
|
||||||
|
// https://github.com/tensorflow/models/blob/1af55e018eebce03fb61bba9959a04672536107d/official/resnet/imagenet_preprocessing.py#L253-L256
|
||||||
|
|
||||||
|
val importedGraph = TFGraphMapper.importGraph(parsedGraph)
|
||||||
|
|
||||||
|
//Load labels
|
||||||
|
val labels = labels()
|
||||||
|
|
||||||
|
|
||||||
|
//Perform inference
|
||||||
|
val inputs: List<String> = importedGraph.inputs()
|
||||||
|
Assert.assertEquals(1, inputs.size.toLong())
|
||||||
|
|
||||||
|
val out = "softmax_tensor"
|
||||||
|
val m: Map<String, INDArray> = importedGraph.output(Collections.singletonMap(inputs[0], img), out)
|
||||||
|
|
||||||
|
val outArr = m[out]
|
||||||
|
|
||||||
|
|
||||||
|
println("SHAPE: " + Arrays.toString(outArr!!.shape()))
|
||||||
|
println(outArr)
|
||||||
|
|
||||||
|
val argmax = outArr!!.argMax(1)
|
||||||
|
|
||||||
|
//Load labels
|
||||||
|
|
||||||
|
val classIdx = argmax.getInt(0)
|
||||||
|
val className = labels[classIdx]
|
||||||
|
val expClass = "golden retriever"
|
||||||
|
val prob = outArr!!.getDouble(classIdx.toLong())
|
||||||
|
|
||||||
|
println("Predicted class: $classIdx - \"$className\" - probability = $prob")
|
||||||
|
Assert.assertEquals(expClass, className)
|
||||||
|
|
||||||
|
val inputMap = Collections.singletonMap(inputs[0], img)
|
||||||
|
val tensorflowIRGraph = TensorflowIRGraph(parsedGraph,tensorflowOps,tfImporter.registry)
|
||||||
|
val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toMutableSet()
|
||||||
|
val tfGraphRunner = TensorflowIRGraphRunner(tensorflowIRGraph, inputMap.keys.toList(),listOf("batch_normalization/FusedBatchNorm",out))
|
||||||
|
val graph = tfImporter.importFromGraph(parsedGraph,inputMap)
|
||||||
|
val tfOutput = tfGraphRunner.run(inputMap)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* TODO: UnsortedSegmentSum ,Solution is almost there, need to figure out how to
|
||||||
|
* output correct shape.
|
||||||
|
*
|
||||||
|
* Shape in TF is 5 x 5 but actual real output seems to be 1 x 10.
|
||||||
|
* We need to change the output shape to work like TF does.
|
||||||
|
*/
|
||||||
|
val output2 = importedGraph.outputAll(inputMap)
|
||||||
|
val output = graph.outputAll(inputMap)
|
||||||
|
|
||||||
|
|
||||||
|
//assertEquals(tfOutput.keys,outputList)
|
||||||
|
//assertEquals(tfOutput.keys,output2.keys)
|
||||||
|
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
|
||||||
|
val skipValidation = setOf("parallel_stack/ExpandDims/dim")
|
||||||
|
//assertEquals(output.keys,output2.keys)
|
||||||
|
val notEquals = LinkedHashSet<String>()
|
||||||
|
val notEqualsTf = LinkedHashSet<String>()
|
||||||
|
val notEqualsOp = LinkedHashSet<String>()
|
||||||
|
names.forEach {
|
||||||
|
val value = output[it]
|
||||||
|
val value2 = output2[it]
|
||||||
|
val tfValue = tfOutput[it]
|
||||||
|
if(value!! != (value2!!)) {
|
||||||
|
val oldOps = importedGraph.ops[it]
|
||||||
|
val newOps = graph.ops[it]
|
||||||
|
val oldVar = importedGraph.variables[it]
|
||||||
|
val newVar = graph.variables[it]
|
||||||
|
notEquals.add(it)
|
||||||
|
}
|
||||||
|
|
||||||
|
if(tfValue != null && tfValue!! != (value!!)) {
|
||||||
|
val oldOps = importedGraph.ops[it]
|
||||||
|
val newOps = graph.ops[it]
|
||||||
|
val oldVar = importedGraph.variables[it]
|
||||||
|
val newVar = graph.variables[it]
|
||||||
|
notEqualsTf.add(it)
|
||||||
|
}
|
||||||
|
|
||||||
|
val oldOp = importedGraph.ops[it]
|
||||||
|
val newOp = graph.ops[it]
|
||||||
|
if(oldOp != newOp) {
|
||||||
|
notEqualsOp.add(it)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
println(notEquals)
|
||||||
|
println(notEqualsTf)
|
||||||
|
println("Not equals ops $notEqualsOp")
|
||||||
|
println()
|
||||||
|
// assertEquals(output,output2)
|
||||||
|
//assertEquals(tfOutput,output)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Throws(Exception::class)
|
||||||
|
fun labels(): List<String?> {
|
||||||
|
val labelsFile =
|
||||||
|
File("imagenet_labellist.txt")
|
||||||
|
return FileUtils.readLines(labelsFile, StandardCharsets.UTF_8)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
fun manualTest2() {
|
fun manualTest2() {
|
||||||
|
|
|
@ -9137,18 +9137,6 @@ mappings {
|
||||||
ruleType: "tensor"
|
ruleType: "tensor"
|
||||||
inputFrameworkOpName: "CompareAndBitpack"
|
inputFrameworkOpName: "CompareAndBitpack"
|
||||||
}
|
}
|
||||||
rule {
|
|
||||||
ruleName: "valuemapping"
|
|
||||||
functionName: "valuemapping"
|
|
||||||
inputDataTypeName: "T"
|
|
||||||
outputDataTypeName: "dtype"
|
|
||||||
inputToOutput {
|
|
||||||
key: "dtype"
|
|
||||||
value: "T"
|
|
||||||
}
|
|
||||||
ruleType: "attribute"
|
|
||||||
inputFrameworkOpName: "CompareAndBitpack"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
|
|
Loading…
Reference in New Issue