Add new clion rules, fix batch norml
parent
968eaad2dd
commit
5bd386a4f9
|
@ -51,9 +51,9 @@ endif()
|
||||||
|
|
||||||
if(WIN32 AND NOT ANDROID)
|
if(WIN32 AND NOT ANDROID)
|
||||||
get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
|
get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
|
||||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj")
|
||||||
endif()
|
endif()
|
||||||
foreach(dir ${dirs})
|
foreach(dir ${dirs})
|
||||||
message(STATUS "dir='${dir}'")
|
message(STATUS "dir='${dir}'")
|
||||||
endforeach()
|
endforeach()
|
||||||
|
@ -161,8 +161,8 @@ if(SD_CUDA)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (CUDA_FOUND)
|
if (CUDA_FOUND)
|
||||||
message("CUDA include directory: ${CUDA_INCLUDE_DIRS}")
|
message("CUDA include directory: ${CUDA_INCLUDE_DIRS}")
|
||||||
include_directories(${CUDA_INCLUDE_DIRS})
|
include_directories(${CUDA_INCLUDE_DIRS})
|
||||||
message("CUDA found!")
|
message("CUDA found!")
|
||||||
if ("${SD_EXPERIMENTAL}" STREQUAL "yes")
|
if ("${SD_EXPERIMENTAL}" STREQUAL "yes")
|
||||||
message("Experimental mode ENABLED")
|
message("Experimental mode ENABLED")
|
||||||
|
@ -181,7 +181,7 @@ if(SD_CUDA)
|
||||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fPIC")
|
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fPIC")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
string( TOLOWER "${COMPUTE}" COMPUTE_CMP )
|
string( TOLOWER "${COMPUTE}" COMPUTE_CMP )
|
||||||
if ("${COMPUTE_CMP}" STREQUAL "all")
|
if ("${COMPUTE_CMP}" STREQUAL "all")
|
||||||
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Common")
|
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Common")
|
||||||
elseif("${COMPUTE_CMP}" STREQUAL "auto")
|
elseif("${COMPUTE_CMP}" STREQUAL "auto")
|
||||||
|
@ -197,9 +197,9 @@ if(SD_CUDA)
|
||||||
endif()
|
endif()
|
||||||
# list to spaces
|
# list to spaces
|
||||||
string (REPLACE ";" " " CUDA_ARCH_FLAGS "${CUDA_ARCH_FLAGS}")
|
string (REPLACE ";" " " CUDA_ARCH_FLAGS "${CUDA_ARCH_FLAGS}")
|
||||||
|
|
||||||
set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR} ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all ${CUDA_ARCH_FLAGS}")
|
set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR} ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all ${CUDA_ARCH_FLAGS}")
|
||||||
|
|
||||||
file(GLOB_RECURSE PERF_SOURCES false ../include/performance/*.cpp ../include/performance/*.h)
|
file(GLOB_RECURSE PERF_SOURCES false ../include/performance/*.cpp ../include/performance/*.h)
|
||||||
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
|
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
|
||||||
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
|
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
|
||||||
|
@ -218,23 +218,23 @@ if(SD_CUDA)
|
||||||
|
|
||||||
|
|
||||||
file(GLOB_RECURSE COMPILATION_UNITS false ../include/loops/cuda/compilation_units/*.cu.in
|
file(GLOB_RECURSE COMPILATION_UNITS false ../include/loops/cuda/compilation_units/*.cu.in
|
||||||
../include/ops/impl/compilation_units/*.cpp.in)
|
../include/ops/impl/compilation_units/*.cpp.in)
|
||||||
|
|
||||||
foreach(FL_ITEM ${COMPILATION_UNITS})
|
foreach(FL_ITEM ${COMPILATION_UNITS})
|
||||||
genCompilation(FL_ITEM)
|
genCompilation(FL_ITEM)
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
if (HAVE_CUDNN)
|
if (HAVE_CUDNN)
|
||||||
message("cuDNN included")
|
message("cuDNN included")
|
||||||
file(GLOB_RECURSE CUSTOMOPS_CUDNN_SOURCES false ../include/ops/declarable/platform/cudnn/*.cu)
|
file(GLOB_RECURSE CUSTOMOPS_CUDNN_SOURCES false ../include/ops/declarable/platform/cudnn/*.cu)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_library(samediff_obj OBJECT ${LOOPS_SOURCES_CUDA} ${LEGACY_SOURCES}
|
add_library(samediff_obj OBJECT ${LOOPS_SOURCES_CUDA} ${LEGACY_SOURCES}
|
||||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||||
${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
|
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
|
||||||
${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
||||||
)
|
)
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
message("MSVC runtime for library: ${MSVC_RT_LIB}")
|
message("MSVC runtime for library: ${MSVC_RT_LIB}")
|
||||||
|
@ -266,10 +266,10 @@ if(SD_CUDA)
|
||||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14")
|
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
target_link_libraries(${SD_LIBRARY_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN})
|
target_link_libraries(${SD_LIBRARY_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN})
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
|
||||||
|
|
||||||
install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
|
install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
|
||||||
endif(CUDA_FOUND)
|
endif(CUDA_FOUND)
|
||||||
elseif(SD_CPU)
|
elseif(SD_CPU)
|
||||||
|
|
||||||
|
@ -295,13 +295,13 @@ elseif(SD_CPU)
|
||||||
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
||||||
|
|
||||||
|
|
||||||
file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in
|
file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in
|
||||||
../include/loops/cpu/compilation_units/*.cpp.in ../include/helpers/cpu/loops/*.cpp.in
|
../include/loops/cpu/compilation_units/*.cpp.in ../include/helpers/cpu/loops/*.cpp.in
|
||||||
../include/ops/impl/compilation_units/*.cpp.in)
|
../include/ops/impl/compilation_units/*.cpp.in)
|
||||||
|
|
||||||
foreach(FL_ITEM ${COMPILATION_UNITS})
|
foreach(FL_ITEM ${COMPILATION_UNITS})
|
||||||
genCompilation(FL_ITEM)
|
genCompilation(FL_ITEM)
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
if (SD_X86_BUILD)
|
if (SD_X86_BUILD)
|
||||||
# we disable platform optimizations for certains files for linux/macos
|
# we disable platform optimizations for certains files for linux/macos
|
||||||
|
@ -312,36 +312,36 @@ elseif(SD_CPU)
|
||||||
|
|
||||||
|
|
||||||
if(SD_CHECK_VECTORIZATION)
|
if(SD_CHECK_VECTORIZATION)
|
||||||
set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
|
set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
|
||||||
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
|
|
||||||
if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
|
||||||
set(CHECK_VECT_FLAGS "-ftree-vectorize -fsave-optimization-record")
|
|
||||||
#to process fsave-optimization-record we will need our cython version code
|
|
||||||
message("Build Auto vectorization helpers")
|
|
||||||
execute_process(COMMAND "python3" "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/cython_setup.py" "build_ext" "--inplace" WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/" RESULT_VARIABLE ret)
|
|
||||||
message("build='${ret}'")
|
|
||||||
|
|
||||||
#remove fail cases that gcc fails produce sometimes
|
if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||||
file(GLOB_RECURSE FAILURE_CASES false ../include/loops/cpu/compilation_units/reduce3*.cpp)
|
set(CHECK_VECT_FLAGS "-ftree-vectorize -fsave-optimization-record")
|
||||||
#message("*****${FAILURE_CASES}")
|
#to process fsave-optimization-record we will need our cython version code
|
||||||
foreach(FL_ITEM ${FAILURE_CASES})
|
message("Build Auto vectorization helpers")
|
||||||
message("Removing failure cases ${FL_ITEM}")
|
execute_process(COMMAND "python3" "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/cython_setup.py" "build_ext" "--inplace" WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/" RESULT_VARIABLE ret)
|
||||||
list(REMOVE_ITEM VECT_FILES ${FL_ITEM})
|
message("build='${ret}'")
|
||||||
endforeach()
|
|
||||||
else()
|
#remove fail cases that gcc fails produce sometimes
|
||||||
set(CHECK_VECT_FLAGS "-ftree-vectorize -fopt-info-vec-optimized-missed")
|
file(GLOB_RECURSE FAILURE_CASES false ../include/loops/cpu/compilation_units/reduce3*.cpp)
|
||||||
|
#message("*****${FAILURE_CASES}")
|
||||||
|
foreach(FL_ITEM ${FAILURE_CASES})
|
||||||
|
message("Removing failure cases ${FL_ITEM}")
|
||||||
|
list(REMOVE_ITEM VECT_FILES ${FL_ITEM})
|
||||||
|
endforeach()
|
||||||
|
else()
|
||||||
|
set(CHECK_VECT_FLAGS "-ftree-vectorize -fopt-info-vec-optimized-missed")
|
||||||
|
endif()
|
||||||
|
message("CHECK VECTORIZATION ${CHECK_VECT_FLAGS}")
|
||||||
|
set_source_files_properties( ${VECT_FILES} PROPERTIES COMPILE_FLAGS "${CHECK_VECT_FLAGS}" )
|
||||||
endif()
|
endif()
|
||||||
message("CHECK VECTORIZATION ${CHECK_VECT_FLAGS}")
|
endif()
|
||||||
set_source_files_properties( ${VECT_FILES} PROPERTIES COMPILE_FLAGS "${CHECK_VECT_FLAGS}" )
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
message("CPU BLAS")
|
message("CPU BLAS")
|
||||||
add_definitions(-D__CPUBLAS__=true)
|
add_definitions(-D__CPUBLAS__=true)
|
||||||
add_library(samediff_obj OBJECT ${LEGACY_SOURCES}
|
add_library(samediff_obj OBJECT ${LEGACY_SOURCES}
|
||||||
${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
|
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
|
||||||
${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
||||||
if(IOS)
|
if(IOS)
|
||||||
add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
|
add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
|
||||||
|
@ -373,7 +373,11 @@ elseif(SD_CPU)
|
||||||
foreach (_variableName ${_variableNames})
|
foreach (_variableName ${_variableNames})
|
||||||
message(STATUS "${_variableName}=${${_variableName}}")
|
message(STATUS "${_variableName}=${${_variableName}}")
|
||||||
endforeach()
|
endforeach()
|
||||||
target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
|
||||||
|
#This breaks the build. Normally you want to run tests anyways.
|
||||||
|
if(NOT "$ENV{CLION_IDE}")
|
||||||
|
target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
||||||
|
endif()
|
||||||
|
|
||||||
if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
|
if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
|
||||||
message(STATUS "Building minifier...")
|
message(STATUS "Building minifier...")
|
||||||
|
@ -382,7 +386,7 @@ elseif(SD_CPU)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
||||||
message(FATAL_ERROR "You need at least GCC 4.9")
|
message(FATAL_ERROR "You need at least GCC 4.9")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# OpenMP works well pretty much only with GCC
|
# OpenMP works well pretty much only with GCC
|
||||||
|
|
|
@ -26,132 +26,138 @@
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
DECLARE_TYPES(fused_batch_norm) {
|
DECLARE_TYPES(fused_batch_norm) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(sd::DataType::ANY)
|
->setAllowedInputTypes(sd::DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
|
}
|
||||||
|
|
||||||
|
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
|
||||||
|
auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
|
||||||
|
auto scale = INPUT_VARIABLE(1); // [iD]
|
||||||
|
auto offset = INPUT_VARIABLE(2); // [iD]
|
||||||
|
|
||||||
|
auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
|
||||||
|
auto batchMean = OUTPUT_VARIABLE(1); // [iD]
|
||||||
|
auto batchVar = OUTPUT_VARIABLE(2); // [iD]
|
||||||
|
|
||||||
|
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
||||||
|
const bool isTraining = (bool)INT_ARG(1);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
|
||||||
|
|
||||||
|
int bS = x->sizeAt(0); // batch size
|
||||||
|
int iH, iW, iD; // input height, input width, input depth(number of channels)
|
||||||
|
if(dataFormat) {
|
||||||
|
iD = x->sizeAt(1);
|
||||||
|
iH = x->sizeAt(2);
|
||||||
|
iW = x->sizeAt(3);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
iD = x->sizeAt(3);
|
||||||
|
iH = x->sizeAt(1);
|
||||||
|
iW = x->sizeAt(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto xCast = x->cast(sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
|
||||||
|
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
|
||||||
|
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
|
||||||
|
|
||||||
|
NDArray *mean(nullptr), *variance(nullptr);
|
||||||
|
if(!isTraining) {
|
||||||
|
mean = INPUT_VARIABLE(3);
|
||||||
|
variance = INPUT_VARIABLE(4);
|
||||||
|
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str());
|
||||||
|
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
||||||
|
std::vector<Nd4jLong> shape = {iD};
|
||||||
|
mean = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
||||||
|
variance = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
float epsilon;
|
||||||
|
if(block.getTArguments()->size() > 0) {
|
||||||
|
epsilon = (float) (T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
epsilon = 0.001f;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int restSize = x->lengthOf() / iD;
|
||||||
|
|
||||||
|
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, sd::DataType::FLOAT32, block.launchContext());
|
||||||
|
xAffected.assign(xCast);
|
||||||
|
|
||||||
|
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
|
||||||
|
const float restSizeInv = 1.0f / restSize;
|
||||||
|
const float restSizeAdjust = (float)restSize / restSizeMinusOne;
|
||||||
|
|
||||||
|
if(isTraining) {
|
||||||
|
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
|
||||||
|
sum *= restSizeInv;
|
||||||
|
mean->assign(sum);
|
||||||
|
*batchMean = *mean;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
*batchMean = 0.;
|
||||||
|
|
||||||
|
auto xCentered = xAffected - *mean;
|
||||||
|
xAffected -= *mean;
|
||||||
|
|
||||||
|
if(isTraining) {
|
||||||
|
int power = 2;
|
||||||
|
xAffected.applyScalar(scalar::Pow, power, xAffected);
|
||||||
|
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
|
||||||
|
sum *= restSizeInv;
|
||||||
|
variance->assign(sum);
|
||||||
|
auto varOutput = (*variance) * restSizeAdjust;
|
||||||
|
batchVar->assign(varOutput);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
*batchVar = 0.;
|
||||||
|
|
||||||
|
auto scaledVariance = ((*variance + epsilon).transform(transform::RSqrt) * (*scale)).cast(xAffected.dataType());
|
||||||
|
auto xScaled1 = xCentered * scaledVariance;
|
||||||
|
auto xShifted1 = xScaled1 + *offset;
|
||||||
|
|
||||||
|
y->assign(xShifted1);
|
||||||
|
|
||||||
|
if(isTraining) {
|
||||||
|
delete mean;
|
||||||
|
delete variance;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(fused_batch_norm) {
|
||||||
|
auto xShapeInfo = inputShape->at(0);
|
||||||
|
auto scaleShapeInfo = inputShape->at(1);
|
||||||
|
|
||||||
|
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
||||||
|
const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4];
|
||||||
|
|
||||||
|
REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str());
|
||||||
|
|
||||||
|
Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr);
|
||||||
|
|
||||||
|
COPY_SHAPE(xShapeInfo, outShapeInfo);
|
||||||
|
COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo);
|
||||||
|
COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo);
|
||||||
|
|
||||||
|
return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
|
|
||||||
auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
|
|
||||||
auto scale = INPUT_VARIABLE(1); // [iD]
|
|
||||||
auto offset = INPUT_VARIABLE(2); // [iD]
|
|
||||||
|
|
||||||
auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
|
|
||||||
auto batchMean = OUTPUT_VARIABLE(1); // [iD]
|
|
||||||
auto batchVar = OUTPUT_VARIABLE(2); // [iD]
|
|
||||||
|
|
||||||
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
|
||||||
const bool isTraining = (bool)INT_ARG(1);
|
|
||||||
|
|
||||||
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
|
|
||||||
|
|
||||||
int bS = x->sizeAt(0); // batch size
|
|
||||||
int iH, iW, iD; // input height, input width, input depth(number of channels)
|
|
||||||
if(dataFormat) {
|
|
||||||
iD = x->sizeAt(1);
|
|
||||||
iH = x->sizeAt(2);
|
|
||||||
iW = x->sizeAt(3);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
iD = x->sizeAt(3);
|
|
||||||
iH = x->sizeAt(1);
|
|
||||||
iW = x->sizeAt(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
|
|
||||||
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
|
|
||||||
|
|
||||||
NDArray *mean(nullptr), *variance(nullptr);
|
|
||||||
if(!isTraining){
|
|
||||||
mean = INPUT_VARIABLE(3);
|
|
||||||
variance = INPUT_VARIABLE(4);
|
|
||||||
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str());
|
|
||||||
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
|
||||||
std::vector<Nd4jLong> shape = {iD};
|
|
||||||
mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
|
||||||
variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
|
||||||
}
|
|
||||||
|
|
||||||
// FIXME: double?
|
|
||||||
double epsilon;
|
|
||||||
if(block.getTArguments()->size() > 0)
|
|
||||||
epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5;
|
|
||||||
else
|
|
||||||
epsilon = 0.001;
|
|
||||||
|
|
||||||
const int restSize = x->lengthOf() / iD;
|
|
||||||
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
|
|
||||||
xAffected.assign(x);
|
|
||||||
|
|
||||||
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
|
|
||||||
// FIXME: float?
|
|
||||||
const double restSizeInv = 1.0 / restSize;
|
|
||||||
const double restSizeAdjust = (double)restSize / restSizeMinusOne;
|
|
||||||
|
|
||||||
if(isTraining) {
|
|
||||||
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
|
|
||||||
sum *= restSizeInv;
|
|
||||||
mean->assign(sum);
|
|
||||||
*batchMean = *mean;
|
|
||||||
//delete sum;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
*batchMean = 0.;
|
|
||||||
|
|
||||||
xAffected -= *mean;
|
|
||||||
|
|
||||||
if(isTraining) {
|
|
||||||
int power = 2;
|
|
||||||
xAffected.applyScalar(scalar::Pow, power, xAffected);
|
|
||||||
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
|
|
||||||
sum *= restSizeInv;
|
|
||||||
variance->assign(sum);
|
|
||||||
*batchVar = (*variance) * restSizeAdjust;
|
|
||||||
//delete sum;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
*batchVar = 0.;
|
|
||||||
xAffected *= (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset);
|
|
||||||
y->assign( xAffected );
|
|
||||||
|
|
||||||
if(isTraining) {
|
|
||||||
delete mean;
|
|
||||||
delete variance;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(fused_batch_norm) {
|
|
||||||
auto xShapeInfo = inputShape->at(0);
|
|
||||||
auto scaleShapeInfo = inputShape->at(1);
|
|
||||||
|
|
||||||
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
|
||||||
const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4];
|
|
||||||
|
|
||||||
REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str());
|
|
||||||
|
|
||||||
Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr);
|
|
||||||
|
|
||||||
COPY_SHAPE(xShapeInfo, outShapeInfo);
|
|
||||||
COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo);
|
|
||||||
COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo);
|
|
||||||
|
|
||||||
return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -87,9 +87,12 @@ public class FusedBatchNorm extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||||
int n = args().length;
|
int n = args().length;
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
if(!dArguments.isEmpty()) {
|
||||||
|
return Arrays.asList(dArguments.get(0),dArguments.get(0),dArguments.get(0));
|
||||||
|
}
|
||||||
return Arrays.asList(outputDataType == null ? DataType.FLOAT : outputDataType,
|
return Arrays.asList(outputDataType == null ? DataType.FLOAT : outputDataType,
|
||||||
outputDataType == null ? DataType.FLOAT : outputDataType,
|
outputDataType == null ? DataType.FLOAT : outputDataType,
|
||||||
outputDataType == null ? DataType.FLOAT : outputDataType);
|
outputDataType == null ? DataType.FLOAT : outputDataType);
|
||||||
|
|
|
@ -69,10 +69,8 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
* the status of the test failing. No tests will run.
|
* the status of the test failing. No tests will run.
|
||||||
*/
|
*/
|
||||||
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
|
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
|
||||||
"max_pool_with_argmax/int32_int64_padding_SAME",
|
"fused_batch_norm/float32_nhwc"
|
||||||
// "fused_batch_norm/float32_nhwc",
|
// , "fused_batch_norm/float16_nhwc"
|
||||||
"max_pool_with_argmax/int64_int64_padding_SAME"
|
|
||||||
// "fused_batch_norm/float16_nhwc",
|
|
||||||
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -86,9 +84,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod
|
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod
|
||||||
"truncatemod/.*",
|
"truncatemod/.*",
|
||||||
|
|
||||||
//Still failing as of 2019/09/11 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
|
|
||||||
"cnn2d_nn/nhwc_b1_k12_s12_d12_SAME",
|
|
||||||
|
|
||||||
//2019/09/11 - No tensorflow op found for SparseTensorDenseAdd
|
//2019/09/11 - No tensorflow op found for SparseTensorDenseAdd
|
||||||
// 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: SparseTensorDenseAdd
|
// 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: SparseTensorDenseAdd
|
||||||
"confusion/.*",
|
"confusion/.*",
|
||||||
|
|
|
@ -958,7 +958,7 @@ val fusedBatchnormV1 = TensorflowMappingProcess(
|
||||||
"offset" to "offset","mean" to "mean","variance" to "variance"))),
|
"offset" to "offset","mean" to "mean","variance" to "variance"))),
|
||||||
inputFrameworkOpName = "FusedBatchNorm",
|
inputFrameworkOpName = "FusedBatchNorm",
|
||||||
opMappingRegistry = tensorflowOpRegistry,
|
opMappingRegistry = tensorflowOpRegistry,
|
||||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon")),
|
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon","dtype" to "T")),
|
||||||
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
|
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
|
||||||
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
|
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
|
||||||
)
|
)
|
||||||
|
@ -971,7 +971,7 @@ val fusedBatchnormV2 = TensorflowMappingProcess(
|
||||||
"offset" to "offset","mean" to "mean","variance" to "variance"))),
|
"offset" to "offset","mean" to "mean","variance" to "variance"))),
|
||||||
inputFrameworkOpName = "FusedBatchNormV2",
|
inputFrameworkOpName = "FusedBatchNormV2",
|
||||||
opMappingRegistry = tensorflowOpRegistry,
|
opMappingRegistry = tensorflowOpRegistry,
|
||||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon")),
|
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon","dtype" to "T")),
|
||||||
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
|
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
|
||||||
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
|
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
|
||||||
)
|
)
|
||||||
|
@ -983,7 +983,7 @@ val fusedBatchnormV3 = TensorflowMappingProcess(
|
||||||
"offset" to "offset","mean" to "mean","variance" to "variance"))),
|
"offset" to "offset","mean" to "mean","variance" to "variance"))),
|
||||||
inputFrameworkOpName = "FusedBatchNormV3",
|
inputFrameworkOpName = "FusedBatchNormV3",
|
||||||
opMappingRegistry = tensorflowOpRegistry,
|
opMappingRegistry = tensorflowOpRegistry,
|
||||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon")),
|
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon","dtype" to "T")),
|
||||||
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
|
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
|
||||||
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
|
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
|
||||||
)
|
)
|
||||||
|
|
|
@ -8367,10 +8367,16 @@ mappings {
|
||||||
functionName: "valuemapping"
|
functionName: "valuemapping"
|
||||||
inputFloatName: "epsilon"
|
inputFloatName: "epsilon"
|
||||||
outputDoubleName: "epsilon"
|
outputDoubleName: "epsilon"
|
||||||
|
inputDataTypeName: "T"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "epsilon"
|
key: "epsilon"
|
||||||
value: "epsilon"
|
value: "epsilon"
|
||||||
}
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "T"
|
||||||
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "FusedBatchNorm"
|
inputFrameworkOpName: "FusedBatchNorm"
|
||||||
}
|
}
|
||||||
|
@ -12480,10 +12486,16 @@ mappings {
|
||||||
functionName: "valuemapping"
|
functionName: "valuemapping"
|
||||||
inputFloatName: "epsilon"
|
inputFloatName: "epsilon"
|
||||||
outputDoubleName: "epsilon"
|
outputDoubleName: "epsilon"
|
||||||
|
inputDataTypeName: "T"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "epsilon"
|
key: "epsilon"
|
||||||
value: "epsilon"
|
value: "epsilon"
|
||||||
}
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "T"
|
||||||
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "FusedBatchNormV3"
|
inputFrameworkOpName: "FusedBatchNormV3"
|
||||||
}
|
}
|
||||||
|
@ -13056,10 +13068,16 @@ mappings {
|
||||||
functionName: "valuemapping"
|
functionName: "valuemapping"
|
||||||
inputFloatName: "epsilon"
|
inputFloatName: "epsilon"
|
||||||
outputDoubleName: "epsilon"
|
outputDoubleName: "epsilon"
|
||||||
|
inputDataTypeName: "T"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "epsilon"
|
key: "epsilon"
|
||||||
value: "epsilon"
|
value: "epsilon"
|
||||||
}
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "T"
|
||||||
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "FusedBatchNormV2"
|
inputFrameworkOpName: "FusedBatchNormV2"
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,7 +90,9 @@ class TestTensorflowIR {
|
||||||
//val inputMap = mapOf("image" to Nd4j.ones(1,128,128,4))
|
//val inputMap = mapOf("image" to Nd4j.ones(1,128,128,4))
|
||||||
val inputMap = emptyMap<String,INDArray>()
|
val inputMap = emptyMap<String,INDArray>()
|
||||||
val tensorflowIRGraph = TensorflowIRGraph(textGraph,tensorflowOps,tfImporter.registry)
|
val tensorflowIRGraph = TensorflowIRGraph(textGraph,tensorflowOps,tfImporter.registry)
|
||||||
val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toSet()
|
val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toMutableSet()
|
||||||
|
outputList.add("FusedBatchNormV3:1")
|
||||||
|
outputList.add("FusedBatchNormV3:2")
|
||||||
val tfGraphRunner = TensorflowIRGraphRunner(tensorflowIRGraph, inputMap.keys.toList(), outputList.toList())
|
val tfGraphRunner = TensorflowIRGraphRunner(tensorflowIRGraph, inputMap.keys.toList(), outputList.toList())
|
||||||
val importedGraph = TFGraphMapper.importGraph(textGraph)
|
val importedGraph = TFGraphMapper.importGraph(textGraph)
|
||||||
val graph = tfImporter.importFromGraph(textGraph,inputMap)
|
val graph = tfImporter.importFromGraph(textGraph,inputMap)
|
||||||
|
@ -104,7 +106,7 @@ class TestTensorflowIR {
|
||||||
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
|
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
|
||||||
val skipValidation = setOf("parallel_stack/ExpandDims/dim")
|
val skipValidation = setOf("parallel_stack/ExpandDims/dim")
|
||||||
//assertEquals(output.keys,output2.keys)
|
//assertEquals(output.keys,output2.keys)
|
||||||
val notEquals = HashSet<String>()
|
/* val notEquals = HashSet<String>()
|
||||||
names.forEach {
|
names.forEach {
|
||||||
val value = output[it]
|
val value = output[it]
|
||||||
val value2 = output2[it]
|
val value2 = output2[it]
|
||||||
|
@ -115,9 +117,9 @@ class TestTensorflowIR {
|
||||||
val newVar = graph.variables[it]
|
val newVar = graph.variables[it]
|
||||||
notEquals.add(it)
|
notEquals.add(it)
|
||||||
}
|
}
|
||||||
}
|
}*/
|
||||||
|
|
||||||
println(notEquals)
|
//println(notEquals)
|
||||||
|
|
||||||
// assertEquals(output,output2)
|
// assertEquals(output,output2)
|
||||||
//assertEquals(tfOutput,output)
|
//assertEquals(tfOutput,output)
|
||||||
|
|
|
@ -8367,10 +8367,16 @@ mappings {
|
||||||
functionName: "valuemapping"
|
functionName: "valuemapping"
|
||||||
inputFloatName: "epsilon"
|
inputFloatName: "epsilon"
|
||||||
outputDoubleName: "epsilon"
|
outputDoubleName: "epsilon"
|
||||||
|
inputDataTypeName: "T"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "epsilon"
|
key: "epsilon"
|
||||||
value: "epsilon"
|
value: "epsilon"
|
||||||
}
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "T"
|
||||||
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "FusedBatchNorm"
|
inputFrameworkOpName: "FusedBatchNorm"
|
||||||
}
|
}
|
||||||
|
@ -12480,10 +12486,16 @@ mappings {
|
||||||
functionName: "valuemapping"
|
functionName: "valuemapping"
|
||||||
inputFloatName: "epsilon"
|
inputFloatName: "epsilon"
|
||||||
outputDoubleName: "epsilon"
|
outputDoubleName: "epsilon"
|
||||||
|
inputDataTypeName: "T"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "epsilon"
|
key: "epsilon"
|
||||||
value: "epsilon"
|
value: "epsilon"
|
||||||
}
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "T"
|
||||||
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "FusedBatchNormV3"
|
inputFrameworkOpName: "FusedBatchNormV3"
|
||||||
}
|
}
|
||||||
|
@ -13056,10 +13068,16 @@ mappings {
|
||||||
functionName: "valuemapping"
|
functionName: "valuemapping"
|
||||||
inputFloatName: "epsilon"
|
inputFloatName: "epsilon"
|
||||||
outputDoubleName: "epsilon"
|
outputDoubleName: "epsilon"
|
||||||
|
inputDataTypeName: "T"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "epsilon"
|
key: "epsilon"
|
||||||
value: "epsilon"
|
value: "epsilon"
|
||||||
}
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "T"
|
||||||
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "FusedBatchNormV2"
|
inputFrameworkOpName: "FusedBatchNormV2"
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue