Add new clion rules, fix batch norml

master
agibsonccc 2021-02-09 07:44:23 +09:00
parent 968eaad2dd
commit 5bd386a4f9
8 changed files with 232 additions and 186 deletions

View File

@ -51,9 +51,9 @@ endif()
if(WIN32 AND NOT ANDROID)
get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj")
endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj")
endif()
foreach(dir ${dirs})
message(STATUS "dir='${dir}'")
endforeach()
@ -161,8 +161,8 @@ if(SD_CUDA)
endif()
if (CUDA_FOUND)
message("CUDA include directory: ${CUDA_INCLUDE_DIRS}")
include_directories(${CUDA_INCLUDE_DIRS})
message("CUDA include directory: ${CUDA_INCLUDE_DIRS}")
include_directories(${CUDA_INCLUDE_DIRS})
message("CUDA found!")
if ("${SD_EXPERIMENTAL}" STREQUAL "yes")
message("Experimental mode ENABLED")
@ -218,7 +218,7 @@ if(SD_CUDA)
file(GLOB_RECURSE COMPILATION_UNITS false ../include/loops/cuda/compilation_units/*.cu.in
../include/ops/impl/compilation_units/*.cpp.in)
../include/ops/impl/compilation_units/*.cpp.in)
foreach(FL_ITEM ${COMPILATION_UNITS})
genCompilation(FL_ITEM)
@ -229,12 +229,12 @@ if(SD_CUDA)
file(GLOB_RECURSE CUSTOMOPS_CUDNN_SOURCES false ../include/ops/declarable/platform/cudnn/*.cu)
endif()
add_library(samediff_obj OBJECT ${LOOPS_SOURCES_CUDA} ${LEGACY_SOURCES}
add_library(samediff_obj OBJECT ${LOOPS_SOURCES_CUDA} ${LEGACY_SOURCES}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
)
${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
)
if (WIN32)
message("MSVC runtime for library: ${MSVC_RT_LIB}")
@ -266,10 +266,10 @@ if(SD_CUDA)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14")
endif()
target_link_libraries(${SD_LIBRARY_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
target_link_libraries(${SD_LIBRARY_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
endif(CUDA_FOUND)
elseif(SD_CPU)
@ -296,8 +296,8 @@ elseif(SD_CPU)
file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in
../include/loops/cpu/compilation_units/*.cpp.in ../include/helpers/cpu/loops/*.cpp.in
../include/ops/impl/compilation_units/*.cpp.in)
../include/loops/cpu/compilation_units/*.cpp.in ../include/helpers/cpu/loops/*.cpp.in
../include/ops/impl/compilation_units/*.cpp.in)
foreach(FL_ITEM ${COMPILATION_UNITS})
genCompilation(FL_ITEM)
@ -312,29 +312,29 @@ elseif(SD_CPU)
if(SD_CHECK_VECTORIZATION)
set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
set(CHECK_VECT_FLAGS "-ftree-vectorize -fsave-optimization-record")
#to process fsave-optimization-record we will need our cython version code
message("Build Auto vectorization helpers")
execute_process(COMMAND "python3" "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/cython_setup.py" "build_ext" "--inplace" WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/" RESULT_VARIABLE ret)
message("build='${ret}'")
if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
set(CHECK_VECT_FLAGS "-ftree-vectorize -fsave-optimization-record")
#to process fsave-optimization-record we will need our cython version code
message("Build Auto vectorization helpers")
execute_process(COMMAND "python3" "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/cython_setup.py" "build_ext" "--inplace" WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/" RESULT_VARIABLE ret)
message("build='${ret}'")
#remove fail cases that gcc fails produce sometimes
file(GLOB_RECURSE FAILURE_CASES false ../include/loops/cpu/compilation_units/reduce3*.cpp)
#message("*****${FAILURE_CASES}")
foreach(FL_ITEM ${FAILURE_CASES})
message("Removing failure cases ${FL_ITEM}")
list(REMOVE_ITEM VECT_FILES ${FL_ITEM})
endforeach()
else()
set(CHECK_VECT_FLAGS "-ftree-vectorize -fopt-info-vec-optimized-missed")
#remove fail cases that gcc fails produce sometimes
file(GLOB_RECURSE FAILURE_CASES false ../include/loops/cpu/compilation_units/reduce3*.cpp)
#message("*****${FAILURE_CASES}")
foreach(FL_ITEM ${FAILURE_CASES})
message("Removing failure cases ${FL_ITEM}")
list(REMOVE_ITEM VECT_FILES ${FL_ITEM})
endforeach()
else()
set(CHECK_VECT_FLAGS "-ftree-vectorize -fopt-info-vec-optimized-missed")
endif()
message("CHECK VECTORIZATION ${CHECK_VECT_FLAGS}")
set_source_files_properties( ${VECT_FILES} PROPERTIES COMPILE_FLAGS "${CHECK_VECT_FLAGS}" )
endif()
message("CHECK VECTORIZATION ${CHECK_VECT_FLAGS}")
set_source_files_properties( ${VECT_FILES} PROPERTIES COMPILE_FLAGS "${CHECK_VECT_FLAGS}" )
endif()
endif()
message("CPU BLAS")
@ -373,7 +373,11 @@ elseif(SD_CPU)
foreach (_variableName ${_variableNames})
message(STATUS "${_variableName}=${${_variableName}}")
endforeach()
target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
#This breaks the build. Normally you want to run tests anyways.
if(NOT "$ENV{CLION_IDE}")
target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
endif()
if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
message(STATUS "Building minifier...")
@ -382,7 +386,7 @@ elseif(SD_CPU)
endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
message(FATAL_ERROR "You need at least GCC 4.9")
message(FATAL_ERROR "You need at least GCC 4.9")
endif()
# OpenMP works well pretty much only with GCC

View File

@ -26,132 +26,138 @@
#include <ops/declarable/CustomOperations.h>
namespace sd {
namespace ops {
namespace ops {
DECLARE_TYPES(fused_batch_norm) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
auto scale = INPUT_VARIABLE(1); // [iD]
auto offset = INPUT_VARIABLE(2); // [iD]
auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
auto batchMean = OUTPUT_VARIABLE(1); // [iD]
auto batchVar = OUTPUT_VARIABLE(2); // [iD]
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
const bool isTraining = (bool)INT_ARG(1);
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
int bS = x->sizeAt(0); // batch size
int iH, iW, iD; // input height, input width, input depth(number of channels)
if(dataFormat) {
iD = x->sizeAt(1);
iH = x->sizeAt(2);
iW = x->sizeAt(3);
}
else {
iD = x->sizeAt(3);
iH = x->sizeAt(1);
iW = x->sizeAt(2);
}
auto xCast = x->cast(sd::DataType::FLOAT32);
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
NDArray *mean(nullptr), *variance(nullptr);
if(!isTraining) {
mean = INPUT_VARIABLE(3);
variance = INPUT_VARIABLE(4);
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
}
else {
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
std::vector<Nd4jLong> shape = {iD};
mean = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
variance = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
}
float epsilon;
if(block.getTArguments()->size() > 0) {
epsilon = (float) (T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5);
}
else {
epsilon = 0.001f;
}
const int restSize = x->lengthOf() / iD;
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, sd::DataType::FLOAT32, block.launchContext());
xAffected.assign(xCast);
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
const float restSizeInv = 1.0f / restSize;
const float restSizeAdjust = (float)restSize / restSizeMinusOne;
if(isTraining) {
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
sum *= restSizeInv;
mean->assign(sum);
*batchMean = *mean;
}
else
*batchMean = 0.;
auto xCentered = xAffected - *mean;
xAffected -= *mean;
if(isTraining) {
int power = 2;
xAffected.applyScalar(scalar::Pow, power, xAffected);
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
sum *= restSizeInv;
variance->assign(sum);
auto varOutput = (*variance) * restSizeAdjust;
batchVar->assign(varOutput);
}
else
*batchVar = 0.;
auto scaledVariance = ((*variance + epsilon).transform(transform::RSqrt) * (*scale)).cast(xAffected.dataType());
auto xScaled1 = xCentered * scaledVariance;
auto xShifted1 = xScaled1 + *offset;
y->assign(xShifted1);
if(isTraining) {
delete mean;
delete variance;
}
return Status::OK();
}
DECLARE_SHAPE_FN(fused_batch_norm) {
auto xShapeInfo = inputShape->at(0);
auto scaleShapeInfo = inputShape->at(1);
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4];
REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str());
Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr);
COPY_SHAPE(xShapeInfo, outShapeInfo);
COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo);
COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo);
return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo));
}
DECLARE_TYPES(fused_batch_norm) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
auto scale = INPUT_VARIABLE(1); // [iD]
auto offset = INPUT_VARIABLE(2); // [iD]
auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
auto batchMean = OUTPUT_VARIABLE(1); // [iD]
auto batchVar = OUTPUT_VARIABLE(2); // [iD]
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
const bool isTraining = (bool)INT_ARG(1);
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
int bS = x->sizeAt(0); // batch size
int iH, iW, iD; // input height, input width, input depth(number of channels)
if(dataFormat) {
iD = x->sizeAt(1);
iH = x->sizeAt(2);
iW = x->sizeAt(3);
}
else {
iD = x->sizeAt(3);
iH = x->sizeAt(1);
iW = x->sizeAt(2);
}
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
NDArray *mean(nullptr), *variance(nullptr);
if(!isTraining){
mean = INPUT_VARIABLE(3);
variance = INPUT_VARIABLE(4);
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
}
else {
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
std::vector<Nd4jLong> shape = {iD};
mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
}
// FIXME: double?
double epsilon;
if(block.getTArguments()->size() > 0)
epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5;
else
epsilon = 0.001;
const int restSize = x->lengthOf() / iD;
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
xAffected.assign(x);
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
// FIXME: float?
const double restSizeInv = 1.0 / restSize;
const double restSizeAdjust = (double)restSize / restSizeMinusOne;
if(isTraining) {
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
sum *= restSizeInv;
mean->assign(sum);
*batchMean = *mean;
//delete sum;
}
else
*batchMean = 0.;
xAffected -= *mean;
if(isTraining) {
int power = 2;
xAffected.applyScalar(scalar::Pow, power, xAffected);
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
sum *= restSizeInv;
variance->assign(sum);
*batchVar = (*variance) * restSizeAdjust;
//delete sum;
}
else
*batchVar = 0.;
xAffected *= (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset);
y->assign( xAffected );
if(isTraining) {
delete mean;
delete variance;
}
return Status::OK();
}
DECLARE_SHAPE_FN(fused_batch_norm) {
auto xShapeInfo = inputShape->at(0);
auto scaleShapeInfo = inputShape->at(1);
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4];
REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str());
Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr);
COPY_SHAPE(xShapeInfo, outShapeInfo);
COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo);
COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo);
return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo));
}
}
}
#endif

View File

@ -87,9 +87,12 @@ public class FusedBatchNorm extends DynamicCustomOp {
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
if(!dArguments.isEmpty()) {
return Arrays.asList(dArguments.get(0),dArguments.get(0),dArguments.get(0));
}
return Arrays.asList(outputDataType == null ? DataType.FLOAT : outputDataType,
outputDataType == null ? DataType.FLOAT : outputDataType,
outputDataType == null ? DataType.FLOAT : outputDataType);

View File

@ -69,10 +69,8 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
* the status of the test failing. No tests will run.
*/
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
"max_pool_with_argmax/int32_int64_padding_SAME",
// "fused_batch_norm/float32_nhwc",
"max_pool_with_argmax/int64_int64_padding_SAME"
// "fused_batch_norm/float16_nhwc",
"fused_batch_norm/float32_nhwc"
// , "fused_batch_norm/float16_nhwc"
);
@ -86,9 +84,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod
"truncatemod/.*",
//Still failing as of 2019/09/11 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
"cnn2d_nn/nhwc_b1_k12_s12_d12_SAME",
//2019/09/11 - No tensorflow op found for SparseTensorDenseAdd
// 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: SparseTensorDenseAdd
"confusion/.*",

View File

@ -958,7 +958,7 @@ val fusedBatchnormV1 = TensorflowMappingProcess(
"offset" to "offset","mean" to "mean","variance" to "variance"))),
inputFrameworkOpName = "FusedBatchNorm",
opMappingRegistry = tensorflowOpRegistry,
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon")),
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon","dtype" to "T")),
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
)
@ -971,7 +971,7 @@ val fusedBatchnormV2 = TensorflowMappingProcess(
"offset" to "offset","mean" to "mean","variance" to "variance"))),
inputFrameworkOpName = "FusedBatchNormV2",
opMappingRegistry = tensorflowOpRegistry,
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon")),
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon","dtype" to "T")),
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
)
@ -983,7 +983,7 @@ val fusedBatchnormV3 = TensorflowMappingProcess(
"offset" to "offset","mean" to "mean","variance" to "variance"))),
inputFrameworkOpName = "FusedBatchNormV3",
opMappingRegistry = tensorflowOpRegistry,
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon")),
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon","dtype" to "T")),
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
)

View File

@ -8367,10 +8367,16 @@ mappings {
functionName: "valuemapping"
inputFloatName: "epsilon"
outputDoubleName: "epsilon"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "epsilon"
value: "epsilon"
}
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "FusedBatchNorm"
}
@ -12480,10 +12486,16 @@ mappings {
functionName: "valuemapping"
inputFloatName: "epsilon"
outputDoubleName: "epsilon"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "epsilon"
value: "epsilon"
}
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "FusedBatchNormV3"
}
@ -13056,10 +13068,16 @@ mappings {
functionName: "valuemapping"
inputFloatName: "epsilon"
outputDoubleName: "epsilon"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "epsilon"
value: "epsilon"
}
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "FusedBatchNormV2"
}

View File

@ -90,7 +90,9 @@ class TestTensorflowIR {
//val inputMap = mapOf("image" to Nd4j.ones(1,128,128,4))
val inputMap = emptyMap<String,INDArray>()
val tensorflowIRGraph = TensorflowIRGraph(textGraph,tensorflowOps,tfImporter.registry)
val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toSet()
val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toMutableSet()
outputList.add("FusedBatchNormV3:1")
outputList.add("FusedBatchNormV3:2")
val tfGraphRunner = TensorflowIRGraphRunner(tensorflowIRGraph, inputMap.keys.toList(), outputList.toList())
val importedGraph = TFGraphMapper.importGraph(textGraph)
val graph = tfImporter.importFromGraph(textGraph,inputMap)
@ -104,7 +106,7 @@ class TestTensorflowIR {
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
val skipValidation = setOf("parallel_stack/ExpandDims/dim")
//assertEquals(output.keys,output2.keys)
val notEquals = HashSet<String>()
/* val notEquals = HashSet<String>()
names.forEach {
val value = output[it]
val value2 = output2[it]
@ -115,9 +117,9 @@ class TestTensorflowIR {
val newVar = graph.variables[it]
notEquals.add(it)
}
}
}*/
println(notEquals)
//println(notEquals)
// assertEquals(output,output2)
//assertEquals(tfOutput,output)

View File

@ -8367,10 +8367,16 @@ mappings {
functionName: "valuemapping"
inputFloatName: "epsilon"
outputDoubleName: "epsilon"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "epsilon"
value: "epsilon"
}
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "FusedBatchNorm"
}
@ -12480,10 +12486,16 @@ mappings {
functionName: "valuemapping"
inputFloatName: "epsilon"
outputDoubleName: "epsilon"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "epsilon"
value: "epsilon"
}
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "FusedBatchNormV3"
}
@ -13056,10 +13068,16 @@ mappings {
functionName: "valuemapping"
inputFloatName: "epsilon"
outputDoubleName: "epsilon"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "epsilon"
value: "epsilon"
}
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "FusedBatchNormV2"
}