Add new clion rules, fix batch norml
parent
968eaad2dd
commit
5bd386a4f9
|
@ -51,9 +51,9 @@ endif()
|
|||
|
||||
if(WIN32 AND NOT ANDROID)
|
||||
get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
|
||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj")
|
||||
endif()
|
||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj")
|
||||
endif()
|
||||
foreach(dir ${dirs})
|
||||
message(STATUS "dir='${dir}'")
|
||||
endforeach()
|
||||
|
@ -161,8 +161,8 @@ if(SD_CUDA)
|
|||
endif()
|
||||
|
||||
if (CUDA_FOUND)
|
||||
message("CUDA include directory: ${CUDA_INCLUDE_DIRS}")
|
||||
include_directories(${CUDA_INCLUDE_DIRS})
|
||||
message("CUDA include directory: ${CUDA_INCLUDE_DIRS}")
|
||||
include_directories(${CUDA_INCLUDE_DIRS})
|
||||
message("CUDA found!")
|
||||
if ("${SD_EXPERIMENTAL}" STREQUAL "yes")
|
||||
message("Experimental mode ENABLED")
|
||||
|
@ -218,7 +218,7 @@ if(SD_CUDA)
|
|||
|
||||
|
||||
file(GLOB_RECURSE COMPILATION_UNITS false ../include/loops/cuda/compilation_units/*.cu.in
|
||||
../include/ops/impl/compilation_units/*.cpp.in)
|
||||
../include/ops/impl/compilation_units/*.cpp.in)
|
||||
|
||||
foreach(FL_ITEM ${COMPILATION_UNITS})
|
||||
genCompilation(FL_ITEM)
|
||||
|
@ -229,12 +229,12 @@ if(SD_CUDA)
|
|||
file(GLOB_RECURSE CUSTOMOPS_CUDNN_SOURCES false ../include/ops/declarable/platform/cudnn/*.cu)
|
||||
endif()
|
||||
|
||||
add_library(samediff_obj OBJECT ${LOOPS_SOURCES_CUDA} ${LEGACY_SOURCES}
|
||||
add_library(samediff_obj OBJECT ${LOOPS_SOURCES_CUDA} ${LEGACY_SOURCES}
|
||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||
${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
|
||||
${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
||||
)
|
||||
${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
||||
)
|
||||
|
||||
if (WIN32)
|
||||
message("MSVC runtime for library: ${MSVC_RT_LIB}")
|
||||
|
@ -266,10 +266,10 @@ if(SD_CUDA)
|
|||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14")
|
||||
endif()
|
||||
|
||||
target_link_libraries(${SD_LIBRARY_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN})
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
|
||||
target_link_libraries(${SD_LIBRARY_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN})
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
|
||||
|
||||
install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
|
||||
install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
|
||||
endif(CUDA_FOUND)
|
||||
elseif(SD_CPU)
|
||||
|
||||
|
@ -296,8 +296,8 @@ elseif(SD_CPU)
|
|||
|
||||
|
||||
file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in
|
||||
../include/loops/cpu/compilation_units/*.cpp.in ../include/helpers/cpu/loops/*.cpp.in
|
||||
../include/ops/impl/compilation_units/*.cpp.in)
|
||||
../include/loops/cpu/compilation_units/*.cpp.in ../include/helpers/cpu/loops/*.cpp.in
|
||||
../include/ops/impl/compilation_units/*.cpp.in)
|
||||
|
||||
foreach(FL_ITEM ${COMPILATION_UNITS})
|
||||
genCompilation(FL_ITEM)
|
||||
|
@ -312,29 +312,29 @@ elseif(SD_CPU)
|
|||
|
||||
|
||||
if(SD_CHECK_VECTORIZATION)
|
||||
set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
|
||||
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||
set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
|
||||
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||
|
||||
if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
set(CHECK_VECT_FLAGS "-ftree-vectorize -fsave-optimization-record")
|
||||
#to process fsave-optimization-record we will need our cython version code
|
||||
message("Build Auto vectorization helpers")
|
||||
execute_process(COMMAND "python3" "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/cython_setup.py" "build_ext" "--inplace" WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/" RESULT_VARIABLE ret)
|
||||
message("build='${ret}'")
|
||||
if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
set(CHECK_VECT_FLAGS "-ftree-vectorize -fsave-optimization-record")
|
||||
#to process fsave-optimization-record we will need our cython version code
|
||||
message("Build Auto vectorization helpers")
|
||||
execute_process(COMMAND "python3" "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/cython_setup.py" "build_ext" "--inplace" WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/" RESULT_VARIABLE ret)
|
||||
message("build='${ret}'")
|
||||
|
||||
#remove fail cases that gcc fails produce sometimes
|
||||
file(GLOB_RECURSE FAILURE_CASES false ../include/loops/cpu/compilation_units/reduce3*.cpp)
|
||||
#message("*****${FAILURE_CASES}")
|
||||
foreach(FL_ITEM ${FAILURE_CASES})
|
||||
message("Removing failure cases ${FL_ITEM}")
|
||||
list(REMOVE_ITEM VECT_FILES ${FL_ITEM})
|
||||
endforeach()
|
||||
else()
|
||||
set(CHECK_VECT_FLAGS "-ftree-vectorize -fopt-info-vec-optimized-missed")
|
||||
#remove fail cases that gcc fails produce sometimes
|
||||
file(GLOB_RECURSE FAILURE_CASES false ../include/loops/cpu/compilation_units/reduce3*.cpp)
|
||||
#message("*****${FAILURE_CASES}")
|
||||
foreach(FL_ITEM ${FAILURE_CASES})
|
||||
message("Removing failure cases ${FL_ITEM}")
|
||||
list(REMOVE_ITEM VECT_FILES ${FL_ITEM})
|
||||
endforeach()
|
||||
else()
|
||||
set(CHECK_VECT_FLAGS "-ftree-vectorize -fopt-info-vec-optimized-missed")
|
||||
endif()
|
||||
message("CHECK VECTORIZATION ${CHECK_VECT_FLAGS}")
|
||||
set_source_files_properties( ${VECT_FILES} PROPERTIES COMPILE_FLAGS "${CHECK_VECT_FLAGS}" )
|
||||
endif()
|
||||
message("CHECK VECTORIZATION ${CHECK_VECT_FLAGS}")
|
||||
set_source_files_properties( ${VECT_FILES} PROPERTIES COMPILE_FLAGS "${CHECK_VECT_FLAGS}" )
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message("CPU BLAS")
|
||||
|
@ -373,7 +373,11 @@ elseif(SD_CPU)
|
|||
foreach (_variableName ${_variableNames})
|
||||
message(STATUS "${_variableName}=${${_variableName}}")
|
||||
endforeach()
|
||||
target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
||||
|
||||
#This breaks the build. Normally you want to run tests anyways.
|
||||
if(NOT "$ENV{CLION_IDE}")
|
||||
target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
||||
endif()
|
||||
|
||||
if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
|
||||
message(STATUS "Building minifier...")
|
||||
|
@ -382,7 +386,7 @@ elseif(SD_CPU)
|
|||
endif()
|
||||
|
||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
||||
message(FATAL_ERROR "You need at least GCC 4.9")
|
||||
message(FATAL_ERROR "You need at least GCC 4.9")
|
||||
endif()
|
||||
|
||||
# OpenMP works well pretty much only with GCC
|
||||
|
|
|
@ -26,132 +26,138 @@
|
|||
#include <ops/declarable/CustomOperations.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace ops {
|
||||
|
||||
DECLARE_TYPES(fused_batch_norm) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(sd::DataType::ANY)
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
|
||||
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
|
||||
auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
|
||||
auto scale = INPUT_VARIABLE(1); // [iD]
|
||||
auto offset = INPUT_VARIABLE(2); // [iD]
|
||||
|
||||
auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
|
||||
auto batchMean = OUTPUT_VARIABLE(1); // [iD]
|
||||
auto batchVar = OUTPUT_VARIABLE(2); // [iD]
|
||||
|
||||
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
||||
const bool isTraining = (bool)INT_ARG(1);
|
||||
|
||||
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
|
||||
|
||||
int bS = x->sizeAt(0); // batch size
|
||||
int iH, iW, iD; // input height, input width, input depth(number of channels)
|
||||
if(dataFormat) {
|
||||
iD = x->sizeAt(1);
|
||||
iH = x->sizeAt(2);
|
||||
iW = x->sizeAt(3);
|
||||
}
|
||||
else {
|
||||
iD = x->sizeAt(3);
|
||||
iH = x->sizeAt(1);
|
||||
iW = x->sizeAt(2);
|
||||
}
|
||||
|
||||
auto xCast = x->cast(sd::DataType::FLOAT32);
|
||||
|
||||
|
||||
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
|
||||
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
|
||||
|
||||
NDArray *mean(nullptr), *variance(nullptr);
|
||||
if(!isTraining) {
|
||||
mean = INPUT_VARIABLE(3);
|
||||
variance = INPUT_VARIABLE(4);
|
||||
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str());
|
||||
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
|
||||
}
|
||||
else {
|
||||
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
||||
std::vector<Nd4jLong> shape = {iD};
|
||||
mean = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
||||
variance = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
||||
}
|
||||
|
||||
|
||||
float epsilon;
|
||||
if(block.getTArguments()->size() > 0) {
|
||||
epsilon = (float) (T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5);
|
||||
}
|
||||
else {
|
||||
epsilon = 0.001f;
|
||||
}
|
||||
|
||||
const int restSize = x->lengthOf() / iD;
|
||||
|
||||
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, sd::DataType::FLOAT32, block.launchContext());
|
||||
xAffected.assign(xCast);
|
||||
|
||||
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
|
||||
const float restSizeInv = 1.0f / restSize;
|
||||
const float restSizeAdjust = (float)restSize / restSizeMinusOne;
|
||||
|
||||
if(isTraining) {
|
||||
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
|
||||
sum *= restSizeInv;
|
||||
mean->assign(sum);
|
||||
*batchMean = *mean;
|
||||
}
|
||||
else
|
||||
*batchMean = 0.;
|
||||
|
||||
auto xCentered = xAffected - *mean;
|
||||
xAffected -= *mean;
|
||||
|
||||
if(isTraining) {
|
||||
int power = 2;
|
||||
xAffected.applyScalar(scalar::Pow, power, xAffected);
|
||||
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
|
||||
sum *= restSizeInv;
|
||||
variance->assign(sum);
|
||||
auto varOutput = (*variance) * restSizeAdjust;
|
||||
batchVar->assign(varOutput);
|
||||
}
|
||||
else
|
||||
*batchVar = 0.;
|
||||
|
||||
auto scaledVariance = ((*variance + epsilon).transform(transform::RSqrt) * (*scale)).cast(xAffected.dataType());
|
||||
auto xScaled1 = xCentered * scaledVariance;
|
||||
auto xShifted1 = xScaled1 + *offset;
|
||||
|
||||
y->assign(xShifted1);
|
||||
|
||||
if(isTraining) {
|
||||
delete mean;
|
||||
delete variance;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
|
||||
DECLARE_SHAPE_FN(fused_batch_norm) {
|
||||
auto xShapeInfo = inputShape->at(0);
|
||||
auto scaleShapeInfo = inputShape->at(1);
|
||||
|
||||
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
||||
const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4];
|
||||
|
||||
REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str());
|
||||
|
||||
Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr);
|
||||
|
||||
COPY_SHAPE(xShapeInfo, outShapeInfo);
|
||||
COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo);
|
||||
COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo);
|
||||
|
||||
return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo));
|
||||
}
|
||||
|
||||
DECLARE_TYPES(fused_batch_norm) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(sd::DataType::ANY)
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
|
||||
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
|
||||
auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
|
||||
auto scale = INPUT_VARIABLE(1); // [iD]
|
||||
auto offset = INPUT_VARIABLE(2); // [iD]
|
||||
|
||||
auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
|
||||
auto batchMean = OUTPUT_VARIABLE(1); // [iD]
|
||||
auto batchVar = OUTPUT_VARIABLE(2); // [iD]
|
||||
|
||||
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
||||
const bool isTraining = (bool)INT_ARG(1);
|
||||
|
||||
REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
|
||||
|
||||
int bS = x->sizeAt(0); // batch size
|
||||
int iH, iW, iD; // input height, input width, input depth(number of channels)
|
||||
if(dataFormat) {
|
||||
iD = x->sizeAt(1);
|
||||
iH = x->sizeAt(2);
|
||||
iW = x->sizeAt(3);
|
||||
}
|
||||
else {
|
||||
iD = x->sizeAt(3);
|
||||
iH = x->sizeAt(1);
|
||||
iW = x->sizeAt(2);
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
|
||||
REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
|
||||
|
||||
NDArray *mean(nullptr), *variance(nullptr);
|
||||
if(!isTraining){
|
||||
mean = INPUT_VARIABLE(3);
|
||||
variance = INPUT_VARIABLE(4);
|
||||
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str());
|
||||
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
|
||||
}
|
||||
else {
|
||||
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
||||
std::vector<Nd4jLong> shape = {iD};
|
||||
mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||
variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||
}
|
||||
|
||||
// FIXME: double?
|
||||
double epsilon;
|
||||
if(block.getTArguments()->size() > 0)
|
||||
epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5;
|
||||
else
|
||||
epsilon = 0.001;
|
||||
|
||||
const int restSize = x->lengthOf() / iD;
|
||||
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
|
||||
xAffected.assign(x);
|
||||
|
||||
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
|
||||
// FIXME: float?
|
||||
const double restSizeInv = 1.0 / restSize;
|
||||
const double restSizeAdjust = (double)restSize / restSizeMinusOne;
|
||||
|
||||
if(isTraining) {
|
||||
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
|
||||
sum *= restSizeInv;
|
||||
mean->assign(sum);
|
||||
*batchMean = *mean;
|
||||
//delete sum;
|
||||
}
|
||||
else
|
||||
*batchMean = 0.;
|
||||
|
||||
xAffected -= *mean;
|
||||
|
||||
if(isTraining) {
|
||||
int power = 2;
|
||||
xAffected.applyScalar(scalar::Pow, power, xAffected);
|
||||
auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
|
||||
sum *= restSizeInv;
|
||||
variance->assign(sum);
|
||||
*batchVar = (*variance) * restSizeAdjust;
|
||||
//delete sum;
|
||||
}
|
||||
else
|
||||
*batchVar = 0.;
|
||||
xAffected *= (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset);
|
||||
y->assign( xAffected );
|
||||
|
||||
if(isTraining) {
|
||||
delete mean;
|
||||
delete variance;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
|
||||
DECLARE_SHAPE_FN(fused_batch_norm) {
|
||||
auto xShapeInfo = inputShape->at(0);
|
||||
auto scaleShapeInfo = inputShape->at(1);
|
||||
|
||||
const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
|
||||
const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4];
|
||||
|
||||
REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str());
|
||||
|
||||
Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr);
|
||||
|
||||
COPY_SHAPE(xShapeInfo, outShapeInfo);
|
||||
COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo);
|
||||
COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo);
|
||||
|
||||
return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -87,9 +87,12 @@ public class FusedBatchNorm extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
if(!dArguments.isEmpty()) {
|
||||
return Arrays.asList(dArguments.get(0),dArguments.get(0),dArguments.get(0));
|
||||
}
|
||||
return Arrays.asList(outputDataType == null ? DataType.FLOAT : outputDataType,
|
||||
outputDataType == null ? DataType.FLOAT : outputDataType,
|
||||
outputDataType == null ? DataType.FLOAT : outputDataType);
|
||||
|
|
|
@ -69,10 +69,8 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
* the status of the test failing. No tests will run.
|
||||
*/
|
||||
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
|
||||
"max_pool_with_argmax/int32_int64_padding_SAME",
|
||||
// "fused_batch_norm/float32_nhwc",
|
||||
"max_pool_with_argmax/int64_int64_padding_SAME"
|
||||
// "fused_batch_norm/float16_nhwc",
|
||||
"fused_batch_norm/float32_nhwc"
|
||||
// , "fused_batch_norm/float16_nhwc"
|
||||
|
||||
);
|
||||
|
||||
|
@ -86,9 +84,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod
|
||||
"truncatemod/.*",
|
||||
|
||||
//Still failing as of 2019/09/11 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
|
||||
"cnn2d_nn/nhwc_b1_k12_s12_d12_SAME",
|
||||
|
||||
//2019/09/11 - No tensorflow op found for SparseTensorDenseAdd
|
||||
// 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: SparseTensorDenseAdd
|
||||
"confusion/.*",
|
||||
|
|
|
@ -958,7 +958,7 @@ val fusedBatchnormV1 = TensorflowMappingProcess(
|
|||
"offset" to "offset","mean" to "mean","variance" to "variance"))),
|
||||
inputFrameworkOpName = "FusedBatchNorm",
|
||||
opMappingRegistry = tensorflowOpRegistry,
|
||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon")),
|
||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon","dtype" to "T")),
|
||||
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
|
||||
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
|
||||
)
|
||||
|
@ -971,7 +971,7 @@ val fusedBatchnormV2 = TensorflowMappingProcess(
|
|||
"offset" to "offset","mean" to "mean","variance" to "variance"))),
|
||||
inputFrameworkOpName = "FusedBatchNormV2",
|
||||
opMappingRegistry = tensorflowOpRegistry,
|
||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon")),
|
||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon","dtype" to "T")),
|
||||
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
|
||||
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
|
||||
)
|
||||
|
@ -983,7 +983,7 @@ val fusedBatchnormV3 = TensorflowMappingProcess(
|
|||
"offset" to "offset","mean" to "mean","variance" to "variance"))),
|
||||
inputFrameworkOpName = "FusedBatchNormV3",
|
||||
opMappingRegistry = tensorflowOpRegistry,
|
||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon")),
|
||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("epsilon" to "epsilon","dtype" to "T")),
|
||||
invertBooleanNumber(mutableMapOf("isTraining" to "is_training")),
|
||||
stringEqualsRule(outputAttribute = "dataFormat",inputFrameworkAttributeName = "data_format",valueToTest = "NCHW",argumentIndex = 0))
|
||||
)
|
||||
|
|
|
@ -8367,10 +8367,16 @@ mappings {
|
|||
functionName: "valuemapping"
|
||||
inputFloatName: "epsilon"
|
||||
outputDoubleName: "epsilon"
|
||||
inputDataTypeName: "T"
|
||||
outputDataTypeName: "dtype"
|
||||
inputToOutput {
|
||||
key: "epsilon"
|
||||
value: "epsilon"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "dtype"
|
||||
value: "T"
|
||||
}
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "FusedBatchNorm"
|
||||
}
|
||||
|
@ -12480,10 +12486,16 @@ mappings {
|
|||
functionName: "valuemapping"
|
||||
inputFloatName: "epsilon"
|
||||
outputDoubleName: "epsilon"
|
||||
inputDataTypeName: "T"
|
||||
outputDataTypeName: "dtype"
|
||||
inputToOutput {
|
||||
key: "epsilon"
|
||||
value: "epsilon"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "dtype"
|
||||
value: "T"
|
||||
}
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "FusedBatchNormV3"
|
||||
}
|
||||
|
@ -13056,10 +13068,16 @@ mappings {
|
|||
functionName: "valuemapping"
|
||||
inputFloatName: "epsilon"
|
||||
outputDoubleName: "epsilon"
|
||||
inputDataTypeName: "T"
|
||||
outputDataTypeName: "dtype"
|
||||
inputToOutput {
|
||||
key: "epsilon"
|
||||
value: "epsilon"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "dtype"
|
||||
value: "T"
|
||||
}
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "FusedBatchNormV2"
|
||||
}
|
||||
|
|
|
@ -90,7 +90,9 @@ class TestTensorflowIR {
|
|||
//val inputMap = mapOf("image" to Nd4j.ones(1,128,128,4))
|
||||
val inputMap = emptyMap<String,INDArray>()
|
||||
val tensorflowIRGraph = TensorflowIRGraph(textGraph,tensorflowOps,tfImporter.registry)
|
||||
val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toSet()
|
||||
val outputList = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }.toMutableSet()
|
||||
outputList.add("FusedBatchNormV3:1")
|
||||
outputList.add("FusedBatchNormV3:2")
|
||||
val tfGraphRunner = TensorflowIRGraphRunner(tensorflowIRGraph, inputMap.keys.toList(), outputList.toList())
|
||||
val importedGraph = TFGraphMapper.importGraph(textGraph)
|
||||
val graph = tfImporter.importFromGraph(textGraph,inputMap)
|
||||
|
@ -104,7 +106,7 @@ class TestTensorflowIR {
|
|||
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
|
||||
val skipValidation = setOf("parallel_stack/ExpandDims/dim")
|
||||
//assertEquals(output.keys,output2.keys)
|
||||
val notEquals = HashSet<String>()
|
||||
/* val notEquals = HashSet<String>()
|
||||
names.forEach {
|
||||
val value = output[it]
|
||||
val value2 = output2[it]
|
||||
|
@ -115,9 +117,9 @@ class TestTensorflowIR {
|
|||
val newVar = graph.variables[it]
|
||||
notEquals.add(it)
|
||||
}
|
||||
}
|
||||
}*/
|
||||
|
||||
println(notEquals)
|
||||
//println(notEquals)
|
||||
|
||||
// assertEquals(output,output2)
|
||||
//assertEquals(tfOutput,output)
|
||||
|
|
|
@ -8367,10 +8367,16 @@ mappings {
|
|||
functionName: "valuemapping"
|
||||
inputFloatName: "epsilon"
|
||||
outputDoubleName: "epsilon"
|
||||
inputDataTypeName: "T"
|
||||
outputDataTypeName: "dtype"
|
||||
inputToOutput {
|
||||
key: "epsilon"
|
||||
value: "epsilon"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "dtype"
|
||||
value: "T"
|
||||
}
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "FusedBatchNorm"
|
||||
}
|
||||
|
@ -12480,10 +12486,16 @@ mappings {
|
|||
functionName: "valuemapping"
|
||||
inputFloatName: "epsilon"
|
||||
outputDoubleName: "epsilon"
|
||||
inputDataTypeName: "T"
|
||||
outputDataTypeName: "dtype"
|
||||
inputToOutput {
|
||||
key: "epsilon"
|
||||
value: "epsilon"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "dtype"
|
||||
value: "T"
|
||||
}
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "FusedBatchNormV3"
|
||||
}
|
||||
|
@ -13056,10 +13068,16 @@ mappings {
|
|||
functionName: "valuemapping"
|
||||
inputFloatName: "epsilon"
|
||||
outputDoubleName: "epsilon"
|
||||
inputDataTypeName: "T"
|
||||
outputDataTypeName: "dtype"
|
||||
inputToOutput {
|
||||
key: "epsilon"
|
||||
value: "epsilon"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "dtype"
|
||||
value: "T"
|
||||
}
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "FusedBatchNormV2"
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue