Platform helpers (#8216)
* platform helpers draft Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * disable platform cmake Signed-off-by: raver119 <raver119@gmail.com> * another draft Signed-off-by: raver119 <raver119@gmail.com> * mkldnn convolution refactored Signed-off-by: raver119 <raver119@gmail.com> * minor tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more safety check Signed-off-by: raver119 <raver119@gmail.com> * prototype works Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * force static library mode for mkldnn Signed-off-by: raver119 <raver119@gmail.com> * - ismax fix - experimental arg fix - don't enforce openblas on Apple hardware Signed-off-by: raver119 <raver119@gmail.com> * bunch of small fixes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * declare concurrent Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - MKLDNN version upgrade to 1.0.2 - avgpool2d/maxpool2d APIs update Signed-off-by: raver119 <raver119@gmail.com> * - avgpool2d_bp/maxpool2d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * - conv2d/batchnorm APIs update Signed-off-by: raver119 <raver119@gmail.com> * - lrn/conv2d_bp/conv3d/conv3d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * all ops converted to MKLDNN 1.x Signed-off-by: raver119 <raver119@gmail.com> * bunch of tweaks Signed-off-by: raver119 <raver119@gmail.com> * namespace for platform helpers Signed-off-by: raver119 <raver119@gmail.com> * make sure platform helpers aren't opimized out Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * more of cpu_features Signed-off-by: raver119 <raver119@gmail.com> * - mkldnn removed from java - cpu_features checks in CpuNDArrayFactory Signed-off-by: raver119 <raver119@gmail.com> * F16C definition renamed Signed-off-by: raver119 <raver119@gmail.com> * some mkldnn rearrangements Signed-off-by: raver119 <raver119@gmail.com> * check supported instructions before doing anything Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * missied impl Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC option Signed-off-by: raver119 <raver119@gmail.com> * conv2d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool2d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * maxpool bp leaks fixed Signed-off-by: raver119 <raver119@gmail.com> * printf removed Signed-off-by: raver119 <raver119@gmail.com> * batchnorm fix Signed-off-by: raver119 <raver119@gmail.com> * AVX warning/error polishing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * More polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * Polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * remove previous MKL-DNN support layer Signed-off-by: raver119 <raver119@gmail.com> * avx2 tweak Signed-off-by: raver119 <raver119@gmail.com> * allow static for apple Signed-off-by: raver119@gmail.com <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * restore OPENBLAS_PATH use Signed-off-by: raver119 <raver119@gmail.com> * add runtime check for avx/avx2 support Signed-off-by: raver119 <raver119@gmail.com> * convolution_auto Signed-off-by: raver119 <raver119@gmail.com> * Add logic for helper argument * minor test fix Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * skip OpTracker props for non-x86 builds Signed-off-by: raver119 <raver119@gmail.com> * linux arm isn't x86 :) Signed-off-by: raver119 <raver119@gmail.com> * avx-512 Signed-off-by: raver119 <raver119@gmail.com> * CUDA presets fix Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC Signed-off-by: raver119 <raver119@gmail.com> * prefetchw for avx2 Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC again Signed-off-by: raver119 <raver119@gmail.com>master
parent
ffae024cda
commit
98e2814879
|
@ -8,12 +8,19 @@ set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF)
|
||||||
|
|
||||||
option(BUILD_TESTS "Build tests" OFF)
|
option(BUILD_TESTS "Build tests" OFF)
|
||||||
|
|
||||||
|
set(X86_BUILD false)
|
||||||
|
|
||||||
|
if (NOT IOS_BUILD AND NOT ANDROID_BUILD AND NOT ${ARCH} MATCHES "power*" AND NOT ${ARCH} MATCHES "arm*")
|
||||||
|
set(X86_BUILD true)
|
||||||
|
endif()
|
||||||
|
|
||||||
# -fsanitize=address
|
# -fsanitize=address
|
||||||
# -fsanitize=leak
|
# -fsanitize=leak
|
||||||
if (APPLE)
|
if (APPLE)
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D__APPLE_OS__=true -D_RELEASE=true")
|
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D__APPLE_OS__=true")
|
set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D__APPLE_OS__=true")
|
||||||
elseif(WIN32)
|
elseif(WIN32)
|
||||||
|
set(X86_BUILD true)
|
||||||
if (NOT CUDA_BLAS)
|
if (NOT CUDA_BLAS)
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D_RELEASE=true")
|
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D_RELEASE=true")
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG " -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
|
set(CMAKE_CXX_FLAGS_DEBUG " -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
|
||||||
|
@ -32,6 +39,7 @@ endif()
|
||||||
|
|
||||||
if(NATIVE)
|
if(NATIVE)
|
||||||
IF(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
|
IF(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
|
||||||
|
set(X86_BUILD false)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=native")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=native")
|
||||||
ELSE()
|
ELSE()
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
|
||||||
|
@ -39,21 +47,101 @@ if(NATIVE)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(NOT CUDA_BLAS)
|
if(NOT CUDA_BLAS)
|
||||||
if (NOT "${MKLDNN_PATH}" STREQUAL "")
|
# we need this definition to avoid global memory use within mkldnn
|
||||||
set(HAVE_MKLDNN 1)
|
add_definitions(-DMKLDNN_ENABLE_CONCURRENT_EXEC=true)
|
||||||
include_directories(${MKLDNN_PATH}/include/)
|
|
||||||
link_directories(${MKLDNN_PATH} ${MKLDNN_PATH}/lib/)
|
# there's a chance, we have no BLAS provided externally
|
||||||
IF(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
|
if ("${OPENBLAS_PATH}" STREQUAL "")
|
||||||
set(MKLDNN_LIBRARIES mkldnn mklml_intel)
|
#we don't want static OpenBLAS on Apple
|
||||||
else()
|
set(BLA_STATIC ON)
|
||||||
set(MKLDNN_LIBRARIES mkldnn mklml)
|
if (NOT APPLE)
|
||||||
|
set(BLA_VENDOR "OpenBLAS")
|
||||||
endif()
|
endif()
|
||||||
elseif (NOT "${OPENBLAS_PATH}" STREQUAL "")
|
|
||||||
|
# look around for system blas instead
|
||||||
|
find_package(BLAS REQUIRED)
|
||||||
|
if (BLAS_FOUND)
|
||||||
|
message("Original library: ${BLAS_LIBRARIES}")
|
||||||
|
# workaround for for cmake being unable to find static blas library
|
||||||
|
SET(_TMP_B "")
|
||||||
|
if (APPLE)
|
||||||
|
string(REGEX REPLACE "\\.dylib$" ".lib" _TMP_B "${BLAS_LIBRARIES}")
|
||||||
|
elseif (WIN32)
|
||||||
|
string(REGEX REPLACE "\\.dll" ".lib" _TMP_B "${BLAS_LIBRARIES}")
|
||||||
|
else()
|
||||||
|
string(REGEX REPLACE "\\.so$" ".a" _TMP_B "${BLAS_LIBRARIES}")
|
||||||
|
endif()
|
||||||
|
set(BLAS_LIBRARIES "${_TMP_B}")
|
||||||
|
|
||||||
|
message("Found external BLAS implementation: ${BLAS_LIBRARIES} ")
|
||||||
|
add_definitions(-D__EXTERNAL_BLAS__=true)
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
# if we have externally provided OPENBLAS_PATH - let's use it
|
||||||
set(HAVE_OPENBLAS 1)
|
set(HAVE_OPENBLAS 1)
|
||||||
include_directories(${OPENBLAS_PATH}/include/)
|
include_directories(${OPENBLAS_PATH}/include/)
|
||||||
link_directories(${OPENBLAS_PATH} ${OPENBLAS_PATH}/lib/)
|
link_directories(${OPENBLAS_PATH} ${OPENBLAS_PATH}/lib/)
|
||||||
set(OPENBLAS_LIBRARIES openblas)
|
set(OPENBLAS_LIBRARIES openblas)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# building cpu_features
|
||||||
|
if (X86_BUILD)
|
||||||
|
add_definitions(-DCPU_FEATURES=true)
|
||||||
|
set(BUILD_PIC "ON" CACHE STRING "Hack to enforce fPIC mode" FORCE)
|
||||||
|
configure_file(./CMakeLists.txt.cpu_features.in cpu_features-download/CMakeLists.txt)
|
||||||
|
message("CMAKE_COMMAND: ${CMAKE_COMMAND}")
|
||||||
|
execute_process(COMMAND ${CMAKE_COMMAND} -DBUILD_PIC=ON -G "${CMAKE_GENERATOR}" .
|
||||||
|
RESULT_VARIABLE result
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/cpu_features-download )
|
||||||
|
|
||||||
|
if(result)
|
||||||
|
message(FATAL_ERROR "CMake step for cpu_features failed: ${result}")
|
||||||
|
endif()
|
||||||
|
execute_process(COMMAND ${CMAKE_COMMAND} --build .
|
||||||
|
RESULT_VARIABLE result
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/cpu_features-download )
|
||||||
|
if(result)
|
||||||
|
message(FATAL_ERROR "Build step for cpu_features failed: ${result}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/cpu_features-src
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/cpu_features-build
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
set(CPUF_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cpu_features-src)
|
||||||
|
include_directories(${CPUF_SOURCE_DIR}/include)
|
||||||
|
set(CPU_FEATURES cpu_features)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# new mkl-dnn entry
|
||||||
|
if (${HELPERS_mkldnn})
|
||||||
|
message("Going to pull & build mkldnn")
|
||||||
|
set(HAVE_MKLDNN 1)
|
||||||
|
set(MKLDNN_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE)
|
||||||
|
|
||||||
|
configure_file(./CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt)
|
||||||
|
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
|
||||||
|
RESULT_VARIABLE result
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download )
|
||||||
|
if(result)
|
||||||
|
message(FATAL_ERROR "CMake step for mkldnn failed: ${result}")
|
||||||
|
endif()
|
||||||
|
execute_process(COMMAND ${CMAKE_COMMAND} --build .
|
||||||
|
RESULT_VARIABLE result
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download )
|
||||||
|
if(result)
|
||||||
|
message(FATAL_ERROR "Build step for mkldnn failed: ${result}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
|
||||||
|
set(mkldnn_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build)
|
||||||
|
set(mkldnn_EXT_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src)
|
||||||
|
set(MKLDNN_PATH "${mkldnn_SOURCE_DIR}")
|
||||||
|
include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_EXT_DIR}/include ${mkldnn_SOURCE_DIR})
|
||||||
|
set(MKLDNN mkldnn)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Download and unpack flatbuffers at configure time
|
# Download and unpack flatbuffers at configure time
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
cmake_minimum_required(VERSION 2.8.2)
|
||||||
|
|
||||||
|
project(mkldnn-download NONE)
|
||||||
|
|
||||||
|
include(ExternalProject)
|
||||||
|
ExternalProject_Add(mkldnn
|
||||||
|
GIT_REPOSITORY https://github.com/google/cpu_features.git
|
||||||
|
GIT_TAG v0.4.1
|
||||||
|
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu_features-src"
|
||||||
|
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu_features-build"
|
||||||
|
CONFIGURE_COMMAND ""
|
||||||
|
CMAKE_ARGS "-DBUILD_PIC=ON"
|
||||||
|
BUILD_COMMAND ""
|
||||||
|
INSTALL_COMMAND ""
|
||||||
|
TEST_COMMAND ""
|
||||||
|
)
|
|
@ -5,11 +5,11 @@ project(mkldnn-download NONE)
|
||||||
include(ExternalProject)
|
include(ExternalProject)
|
||||||
ExternalProject_Add(mkldnn
|
ExternalProject_Add(mkldnn
|
||||||
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
||||||
GIT_TAG v0.18.1
|
GIT_TAG v1.0.2
|
||||||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
||||||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
||||||
CONFIGURE_COMMAND "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src/scripts/prepare_mkl.sh"
|
CONFIGURE_COMMAND ""
|
||||||
CMAKE_ARGS -DMKLDNN_USE_MKL=ML -G \"Unix Makefiles\" -DMKLDNN_LIBRARY_TYPE=STATIC
|
CMAKE_ARGS -DMKLDNN_USE_MKL=ML -DMKLDNN_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\"
|
||||||
BUILD_COMMAND ""
|
BUILD_COMMAND ""
|
||||||
INSTALL_COMMAND ""
|
INSTALL_COMMAND ""
|
||||||
TEST_COMMAND ""
|
TEST_COMMAND ""
|
|
@ -78,14 +78,24 @@ IF(${ARCH} MATCHES "arm*")
|
||||||
ELSEIF(${ARCH} MATCHES "power*")
|
ELSEIF(${ARCH} MATCHES "power*")
|
||||||
set(ARCH_TUNE "-mcpu=${ARCH} -mtune=${ARCH} -D__POWER")
|
set(ARCH_TUNE "-mcpu=${ARCH} -mtune=${ARCH} -D__POWER")
|
||||||
ELSEIF(${EXTENSION} MATCHES "avx2")
|
ELSEIF(${EXTENSION} MATCHES "avx2")
|
||||||
set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH} -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -D__F16C__=true")
|
message("Building AVX2 binary...")
|
||||||
|
set(ARCH_TUNE "-mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mprefetchwt1 -DSD_F16C=true -DF_AVX2=true")
|
||||||
ELSE()
|
ELSE()
|
||||||
if ("${ARCH}" STREQUAL "x86-64")
|
if ("${ARCH}" STREQUAL "x86-64")
|
||||||
|
message("Building x86_64 binary...")
|
||||||
set(ARCH_TYPE "generic")
|
set(ARCH_TYPE "generic")
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DF_X64=true")
|
||||||
else()
|
else()
|
||||||
set(ARCH_TYPE "${ARCH}")
|
set(ARCH_TYPE "${ARCH}")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
IF(${EXTENSION} MATCHES "avx512")
|
||||||
|
message("Building AVX512 binary...")
|
||||||
|
# we need to set flag here, that we can use hardware f16 conversion + tell that cpu features should be tracked
|
||||||
|
message("Current CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves -DSD_F16C=true -DF_AVX512=true")
|
||||||
|
endif()
|
||||||
|
|
||||||
set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH_TYPE}")
|
set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH_TYPE}")
|
||||||
ENDIF()
|
ENDIF()
|
||||||
|
|
||||||
|
@ -299,19 +309,31 @@ elseif(CPU_BLAS)
|
||||||
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
|
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
|
||||||
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
|
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
|
||||||
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
|
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
|
||||||
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp ../include/ops/declarable/helpers/impl/*.cpp)
|
file(GLOB_RECURSE CUSTOMOPS_GENERIC_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp ../include/ops/declarable/helpers/impl/*.cpp)
|
||||||
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
|
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
|
||||||
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
|
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
|
||||||
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h)
|
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h)
|
||||||
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
||||||
|
|
||||||
|
|
||||||
|
#if MKLDNN is enabled - we're building mkldnn-powered helpers
|
||||||
|
if (HAVE_MKLDNN)
|
||||||
|
file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (X86_BUILD)
|
||||||
|
#we disable platform optimizations for certains files
|
||||||
|
set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
||||||
|
set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
||||||
|
endif()
|
||||||
|
|
||||||
message("CPU BLAS")
|
message("CPU BLAS")
|
||||||
add_definitions(-D__CPUBLAS__=true)
|
add_definitions(-D__CPUBLAS__=true)
|
||||||
add_library(nd4jobj OBJECT cpu/NativeOps.cpp cpu/GraphExecutioner.cpp
|
add_library(nd4jobj OBJECT cpu/NativeOps.cpp cpu/GraphExecutioner.cpp
|
||||||
cpu/NativeOpExecutioner.cpp cpu/NDArray.cpp cpu/NDArrayFactory.cpp
|
cpu/NativeOpExecutioner.cpp cpu/NDArray.cpp cpu/NDArrayFactory.cpp
|
||||||
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
||||||
Environment.cpp Environment.h ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
Environment.cpp Environment.h ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_HELPERS_SOURCES}
|
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
||||||
${OPS_SOURCES} ${PERF_SOURCES})
|
${OPS_SOURCES} ${PERF_SOURCES})
|
||||||
if(IOS)
|
if(IOS)
|
||||||
add_library(${LIBND4J_NAME} STATIC $<TARGET_OBJECTS:nd4jobj>)
|
add_library(${LIBND4J_NAME} STATIC $<TARGET_OBJECTS:nd4jobj>)
|
||||||
|
@ -320,12 +342,13 @@ elseif(CPU_BLAS)
|
||||||
add_library(${LIBND4J_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
|
add_library(${LIBND4J_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
target_link_libraries(${LIBND4J_NAME} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES})
|
# we're including {MKLDNN} here in case of building from sources. in future that'll replace {MKLDNN_LIBRARIES}. same applies to BLAS
|
||||||
|
target_link_libraries(${LIBND4J_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
||||||
|
|
||||||
if ("${LIBND4J_ALL_OPS}" AND "${LIBND4J_BUILD_MINIFIER}")
|
if ("${LIBND4J_ALL_OPS}" AND "${LIBND4J_BUILD_MINIFIER}")
|
||||||
message(STATUS "Building minifier...")
|
message(STATUS "Building minifier...")
|
||||||
add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
|
add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
|
||||||
target_link_libraries(minifier ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES})
|
target_link_libraries(minifier ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
||||||
|
|
|
@ -1760,6 +1760,13 @@ ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc);
|
||||||
ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
|
ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
|
||||||
ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
|
ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
|
||||||
|
|
||||||
|
|
||||||
|
ND4J_EXPORT int binaryLevel();
|
||||||
|
ND4J_EXPORT int optimalLevel();
|
||||||
|
|
||||||
|
ND4J_EXPORT bool isMinimalRequirementsMet();
|
||||||
|
ND4J_EXPORT bool isOptimalRequirementsMet();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif //NATIVEOPERATIONS_NATIVEOPS_H
|
#endif //NATIVEOPERATIONS_NATIVEOPS_H
|
||||||
|
|
|
@ -76,6 +76,10 @@ bool experimentalSupport = false;
|
||||||
#include <performance/benchmarking/FullBenchmarkSuit.h>
|
#include <performance/benchmarking/FullBenchmarkSuit.h>
|
||||||
#include <performance/benchmarking/LightBenchmarkSuit.h>
|
#include <performance/benchmarking/LightBenchmarkSuit.h>
|
||||||
|
|
||||||
|
#ifdef CPU_FEATURES
|
||||||
|
#include <cpuinfo_x86.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
using namespace nd4j;
|
using namespace nd4j;
|
||||||
|
|
||||||
void setElementThreshold(int num) {
|
void setElementThreshold(int num) {
|
||||||
|
@ -3167,6 +3171,75 @@ const char* lastErrorMessage() {
|
||||||
return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage();
|
return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int binaryLevel() {
|
||||||
|
#ifdef CPU_FEATURES
|
||||||
|
|
||||||
|
|
||||||
|
#if defined(F_X64)
|
||||||
|
return 1;
|
||||||
|
#elif defined (F_AVX2)
|
||||||
|
return 2;
|
||||||
|
#elif defined (F_AVX512)
|
||||||
|
return 3;
|
||||||
|
#else
|
||||||
|
return 0;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#else
|
||||||
|
return 0;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
int optimalLevel() {
|
||||||
|
#ifdef CPU_FEATURES
|
||||||
|
auto features = cpu_features::GetX86Info().features;
|
||||||
|
|
||||||
|
if (features.avx && features.avx2 && features.avx512f && features.avx512vl && features.avx512bw && features.avx512dq && features.avx512cd)
|
||||||
|
return 3;
|
||||||
|
else if (features.avx && features.avx2)
|
||||||
|
return 2;
|
||||||
|
else
|
||||||
|
return 1;
|
||||||
|
|
||||||
|
#else
|
||||||
|
return 0;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isMinimalRequirementsMet() {
|
||||||
|
#ifdef CPU_FEATURES
|
||||||
|
auto features = cpu_features::GetX86Info().features;
|
||||||
|
|
||||||
|
#if defined(F_X64)
|
||||||
|
return true;
|
||||||
|
#elif defined (F_AVX2)
|
||||||
|
return features.avx && features.avx2;
|
||||||
|
#elif defined (F_AVX512)
|
||||||
|
// we're optimizing for skylake-avx512 features, so we'll check those out
|
||||||
|
return features.avx && features.avx2 && features.avx512f && features.avx512vl && features.avx512bw && features.avx512dq && features.avx512cd;
|
||||||
|
#else
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#else
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isOptimalRequirementsMet() {
|
||||||
|
#ifdef CPU_FEATURES
|
||||||
|
auto b = ::binaryLevel();
|
||||||
|
auto o = ::optimalLevel();
|
||||||
|
|
||||||
|
if (b == o)
|
||||||
|
return true;
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
#else
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong**, void**, Nd4jLong**, int, int*, Nd4jLong**, Nd4jLong**), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong**, void**, Nd4jLong**, int, int*, Nd4jLong**, Nd4jLong**), LIBND4J_TYPES);
|
||||||
|
|
|
@ -3577,3 +3577,19 @@ int lastErrorCode() {
|
||||||
const char* lastErrorMessage() {
|
const char* lastErrorMessage() {
|
||||||
return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage();
|
return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int binaryLevel() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int optimalLevel() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isMinimalRequirementsMet() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isOptimalRequirementsMet() {
|
||||||
|
return true;
|
||||||
|
}
|
|
@ -53,6 +53,7 @@ CLEAN="false"
|
||||||
MINIFIER="false"
|
MINIFIER="false"
|
||||||
TESTS="false"
|
TESTS="false"
|
||||||
VERBOSE="false"
|
VERBOSE="false"
|
||||||
|
HELPER=
|
||||||
NAME=
|
NAME=
|
||||||
while [[ $# > 0 ]]
|
while [[ $# > 0 ]]
|
||||||
do
|
do
|
||||||
|
@ -60,6 +61,10 @@ key="$1"
|
||||||
value="${2:-}"
|
value="${2:-}"
|
||||||
#Build type (release/debug), packaging type, chip: cpu,cuda, lib type (static/dynamic)
|
#Build type (release/debug), packaging type, chip: cpu,cuda, lib type (static/dynamic)
|
||||||
case $key in
|
case $key in
|
||||||
|
-h|--helper)
|
||||||
|
HELPER="$value"
|
||||||
|
shift # past argument
|
||||||
|
;;
|
||||||
-o|-platform|--platform)
|
-o|-platform|--platform)
|
||||||
OS="$value"
|
OS="$value"
|
||||||
shift # past argument
|
shift # past argument
|
||||||
|
@ -425,7 +430,7 @@ if [ "$PACKAGING" == "msi" ]; then
|
||||||
PACKAGING_ARG="-DPACKAGING=msi"
|
PACKAGING_ARG="-DPACKAGING=msi"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
EXPERIMENTAL_ARG="no";
|
EXPERIMENTAL_ARG="";
|
||||||
MINIFIER_ARG="-DLIBND4J_BUILD_MINIFIER=false"
|
MINIFIER_ARG="-DLIBND4J_BUILD_MINIFIER=false"
|
||||||
TESTS_ARG="-DBUILD_TESTS=OFF"
|
TESTS_ARG="-DBUILD_TESTS=OFF"
|
||||||
NAME_ARG="-DLIBND4J_NAME=$NAME"
|
NAME_ARG="-DLIBND4J_NAME=$NAME"
|
||||||
|
@ -461,16 +466,12 @@ if [ "$CHIP" == "cuda" ] && [ -n "$CHIP_VERSION" ]; then
|
||||||
esac
|
esac
|
||||||
fi
|
fi
|
||||||
|
|
||||||
[[ -z ${MKLDNN_PATH:-} ]] && MKLDNN_PATH=""
|
|
||||||
[[ -z ${OPENBLAS_PATH:-} ]] && OPENBLAS_PATH=""
|
[[ -z ${OPENBLAS_PATH:-} ]] && OPENBLAS_PATH=""
|
||||||
|
|
||||||
if [[ -n "${BUILD_PATH:-}" ]]; then
|
if [[ -n "${BUILD_PATH:-}" ]]; then
|
||||||
PREVIFS="$IFS"
|
PREVIFS="$IFS"
|
||||||
IFS="$BUILD_PATH_SEPARATOR"
|
IFS="$BUILD_PATH_SEPARATOR"
|
||||||
for P in $BUILD_PATH; do
|
for P in $BUILD_PATH; do
|
||||||
if [[ -f "$P/include/mkldnn.h" ]]; then
|
|
||||||
MKLDNN_PATH="$P"
|
|
||||||
fi
|
|
||||||
if [[ -f "$P/include/openblas_config.h" ]]; then
|
if [[ -f "$P/include/openblas_config.h" ]]; then
|
||||||
OPENBLAS_PATH="$P"
|
OPENBLAS_PATH="$P"
|
||||||
fi
|
fi
|
||||||
|
@ -478,18 +479,12 @@ if [[ -n "${BUILD_PATH:-}" ]]; then
|
||||||
IFS="$PREVIFS"
|
IFS="$PREVIFS"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ ! -f "$MKLDNN_PATH/include/mkldnn.h" ]]; then
|
|
||||||
echo "Could not find MKL-DNN, please make sure to run the build with Maven or set the MKLDNN_PATH variable"
|
|
||||||
MKLDNN_PATH=""
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ ! -f "$OPENBLAS_PATH/include/openblas_config.h" ]]; then
|
if [[ ! -f "$OPENBLAS_PATH/include/openblas_config.h" ]]; then
|
||||||
echo "Could not find OpenBLAS, please make sure to run the build with Maven or set the OPENBLAS_PATH variable"
|
echo "Could not find OpenBLAS, please make sure to run the build with Maven or set the OPENBLAS_PATH variable"
|
||||||
OPENBLAS_PATH=""
|
OPENBLAS_PATH=""
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# replace any backslash with a slash
|
# replace any backslash with a slash
|
||||||
MKLDNN_PATH="${MKLDNN_PATH//\\//}"
|
|
||||||
OPENBLAS_PATH="${OPENBLAS_PATH//\\//}"
|
OPENBLAS_PATH="${OPENBLAS_PATH//\\//}"
|
||||||
|
|
||||||
mkbuilddir() {
|
mkbuilddir() {
|
||||||
|
@ -501,6 +496,21 @@ mkbuilddir() {
|
||||||
cd "blasbuild/$CHIP"
|
cd "blasbuild/$CHIP"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if [ "$HELPER" == "" ]; then
|
||||||
|
echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
|
||||||
|
echo "!! !!"
|
||||||
|
echo "!! !!"
|
||||||
|
echo "!! !!"
|
||||||
|
echo "!! !!"
|
||||||
|
echo "!! WARNING! !!"
|
||||||
|
echo "!! No helper packages configured! !!"
|
||||||
|
echo "!! You can specify helper by using -h key. I.e. <-h mkldnn> !!"
|
||||||
|
echo "!! !!"
|
||||||
|
echo "!! !!"
|
||||||
|
echo "!! !!"
|
||||||
|
echo "!! !!"
|
||||||
|
echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
|
||||||
|
fi
|
||||||
|
|
||||||
echo PACKAGING = "${PACKAGING}"
|
echo PACKAGING = "${PACKAGING}"
|
||||||
echo BUILD = "${BUILD}"
|
echo BUILD = "${BUILD}"
|
||||||
|
@ -515,11 +525,11 @@ echo OPERATIONS = "${OPERATIONS_ARG}"
|
||||||
echo MINIFIER = "${MINIFIER_ARG}"
|
echo MINIFIER = "${MINIFIER_ARG}"
|
||||||
echo TESTS = "${TESTS_ARG}"
|
echo TESTS = "${TESTS_ARG}"
|
||||||
echo NAME = "${NAME_ARG}"
|
echo NAME = "${NAME_ARG}"
|
||||||
echo MKLDNN_PATH = "$MKLDNN_PATH"
|
|
||||||
echo OPENBLAS_PATH = "$OPENBLAS_PATH"
|
echo OPENBLAS_PATH = "$OPENBLAS_PATH"
|
||||||
|
echo HELPERS = "$HELPER"
|
||||||
mkbuilddir
|
mkbuilddir
|
||||||
pwd
|
pwd
|
||||||
eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DMKLDNN_PATH="$MKLDNN_PATH" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../..
|
eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DHELPERS_"$HELPER"=true "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../..
|
||||||
if [ "$PARALLEL" == "true" ]; then
|
if [ "$PARALLEL" == "true" ]; then
|
||||||
MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ"
|
MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ"
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -223,7 +223,7 @@ DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory:
|
||||||
|
|
||||||
setCountersToZero();
|
setCountersToZero();
|
||||||
|
|
||||||
if(!lenInBytes == 0) {
|
if(lenInBytes != 0) {
|
||||||
allocateBuffers(allocBoth);
|
allocateBuffers(allocBoth);
|
||||||
writeSpecial();
|
writeSpecial();
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,9 +31,10 @@
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
#ifdef HAVE_MKLDNN
|
||||||
|
// FIXME: latest mkldnn doesn't ship mklml anymore?
|
||||||
// include CBLAS from MKL-DNN
|
// include CBLAS from MKL-DNN
|
||||||
#include <mkl_cblas.h>
|
//#include <mkl_cblas.h>
|
||||||
#define CBLAS_H
|
//#define CBLAS_H
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef HAVE_OPENBLAS
|
#ifdef HAVE_OPENBLAS
|
||||||
|
|
|
@ -29,6 +29,10 @@
|
||||||
#include <cuda_device_runtime_api.h>
|
#include <cuda_device_runtime_api.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// used for MKLDNN etc
|
||||||
|
#if !defined(__STANDALONE_BUILD__)
|
||||||
|
#include "config.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -49,6 +53,9 @@ class ND4J_EXPORT LaunchContext {
|
||||||
static std::vector<std::shared_ptr<LaunchContext>> _contexts;
|
static std::vector<std::shared_ptr<LaunchContext>> _contexts;
|
||||||
static std::mutex _mutex;
|
static std::mutex _mutex;
|
||||||
|
|
||||||
|
// used for MKLDNN
|
||||||
|
void *_engine = nullptr;
|
||||||
|
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
|
|
||||||
#ifndef __JAVACPP_HACK__
|
#ifndef __JAVACPP_HACK__
|
||||||
|
@ -96,6 +103,8 @@ class ND4J_EXPORT LaunchContext {
|
||||||
_workspace = theWorkspace;
|
_workspace = theWorkspace;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void* engine();
|
||||||
|
|
||||||
int getDeviceID() const {return _deviceID;}
|
int getDeviceID() const {return _deviceID;}
|
||||||
void setDeviceID(int deviceID) { _deviceID = deviceID; }
|
void setDeviceID(int deviceID) { _deviceID = deviceID; }
|
||||||
sd::ErrorReference* errorReference();
|
sd::ErrorReference* errorReference();
|
||||||
|
|
|
@ -29,10 +29,16 @@ nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
|
||||||
thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
|
thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef HAVE_MKLDNN
|
||||||
|
#include <mkldnn.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
|
||||||
LaunchContext::~LaunchContext() {
|
LaunchContext::~LaunchContext() {
|
||||||
|
#ifdef HAVE_MKLDNN
|
||||||
|
delete reinterpret_cast<mkldnn::engine*>(_engine);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
|
std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
|
||||||
|
@ -42,6 +48,10 @@ namespace nd4j {
|
||||||
// default constructor, just to make clang/ranlib happy
|
// default constructor, just to make clang/ranlib happy
|
||||||
_workspace = nullptr;
|
_workspace = nullptr;
|
||||||
_deviceID = 0;
|
_deviceID = 0;
|
||||||
|
|
||||||
|
#ifdef HAVE_MKLDNN
|
||||||
|
_engine = new mkldnn::engine(mkldnn::engine::kind::cpu, 0);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
|
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
|
||||||
|
@ -73,4 +83,8 @@ namespace nd4j {
|
||||||
sd::ErrorReference* LaunchContext::errorReference() {
|
sd::ErrorReference* LaunchContext::errorReference() {
|
||||||
return contextBuffers.errorReference();
|
return contextBuffers.errorReference();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void* LaunchContext::engine() {
|
||||||
|
return _engine;
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -169,4 +169,8 @@ LaunchContext::LaunchContext() {
|
||||||
sd::ErrorReference* LaunchContext::errorReference() {
|
sd::ErrorReference* LaunchContext::errorReference() {
|
||||||
return contextBuffers.errorReference();
|
return contextBuffers.errorReference();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void* LaunchContext::engine() {
|
||||||
|
return _engine;
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -28,10 +28,6 @@
|
||||||
#include <graph/ContextPrototype.h>
|
#include <graph/ContextPrototype.h>
|
||||||
#include <memory/Workspace.h>
|
#include <memory/Workspace.h>
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
#include <MKLDNNStream.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// CUDA-specific includes
|
// CUDA-specific includes
|
||||||
#ifdef __CUDACC__
|
#ifdef __CUDACC__
|
||||||
|
|
||||||
|
@ -61,11 +57,6 @@ namespace nd4j {
|
||||||
LaunchContext* _context = nullptr;
|
LaunchContext* _context = nullptr;
|
||||||
|
|
||||||
std::vector<nd4j::DataType> _dataTypes;
|
std::vector<nd4j::DataType> _dataTypes;
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
std::vector<nd4j::MKLDNNStream> _mkldnnStreams;
|
|
||||||
#else
|
|
||||||
std::vector<Nd4jLong> _mkldnnStreams;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::vector<NDArray*> _fastpath_in;
|
std::vector<NDArray*> _fastpath_in;
|
||||||
std::vector<NDArray*> _fastpath_out;
|
std::vector<NDArray*> _fastpath_out;
|
||||||
|
@ -122,9 +113,6 @@ namespace nd4j {
|
||||||
int getBranch();
|
int getBranch();
|
||||||
void setBranch(int branch);
|
void setBranch(int branch);
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
std::vector<nd4j::MKLDNNStream>& getMKLDNNStreams() { return _mkldnnStreams; }
|
|
||||||
#endif
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
|
|
|
@ -98,9 +98,6 @@ namespace nd4j {
|
||||||
this->_inputs.clear();
|
this->_inputs.clear();
|
||||||
this->_fastpath_in.clear();
|
this->_fastpath_in.clear();
|
||||||
this->_fastpath_out.clear();
|
this->_fastpath_out.clear();
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
this->_mkldnnStreams.clear();
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (auto v:_handles)
|
for (auto v:_handles)
|
||||||
delete v;
|
delete v;
|
||||||
|
|
|
@ -21,12 +21,11 @@
|
||||||
#ifndef LIBND4J_MKLDNNSTREAM_H
|
#ifndef LIBND4J_MKLDNNSTREAM_H
|
||||||
#define LIBND4J_MKLDNNSTREAM_H
|
#define LIBND4J_MKLDNNSTREAM_H
|
||||||
|
|
||||||
#ifndef __STANDALONE_BUILD__
|
#if !defined(__STANDALONE_BUILD__)
|
||||||
#include "config.h"
|
#include "config.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
#if defined(HAVE_MKLDNN)
|
||||||
#include <mkldnn.hpp>
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
class MKLDNNStream {
|
class MKLDNNStream {
|
||||||
|
@ -38,26 +37,24 @@ namespace nd4j {
|
||||||
std::vector<float> _floatArguments;
|
std::vector<float> _floatArguments;
|
||||||
std::vector<int> _intArguments;
|
std::vector<int> _intArguments;
|
||||||
|
|
||||||
mkldnn::engine _engine = mkldnn::engine(mkldnn::engine::cpu, 0);
|
|
||||||
std::vector<mkldnn::memory> _memory;
|
|
||||||
std::vector<mkldnn::primitive> _operations;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static bool isSupported() {
|
static bool isSupported() {
|
||||||
|
// FIXME: strict float support doesn't work anymore
|
||||||
return typeid(X) == typeid(float) && typeid(Y) == typeid(float);
|
return typeid(X) == typeid(float) && typeid(Y) == typeid(float);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isSupported(const std::vector<const NDArray*> &arrays) {
|
static bool isSupported(const std::vector<const NDArray*> &arrays) {
|
||||||
for (auto i = arrays.begin(); i != arrays.end(); i++) {
|
// FIXME: strict float support doesn't work anymore
|
||||||
if (*i != nullptr && (*i)->dataType() != nd4j::DataType::FLOAT32) {
|
for (auto v:arrays) {
|
||||||
|
if (v != nullptr && v->dataType() != nd4j::DataType::FLOAT32) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
MKLDNNStream(const std::string &opName) : _opName(opName) { }
|
explicit MKLDNNStream(const std::string &opName) : _opName(opName) { }
|
||||||
|
|
||||||
bool checkAndReset(const std::vector<const NDArray*> &inputs, const std::vector<const NDArray*> &outputs,
|
bool checkAndReset(const std::vector<const NDArray*> &inputs, const std::vector<const NDArray*> &outputs,
|
||||||
const std::vector<float> &floatArguments, const std::vector<int> &intArguments) {
|
const std::vector<float> &floatArguments, const std::vector<int> &intArguments) {
|
||||||
|
@ -66,30 +63,10 @@ namespace nd4j {
|
||||||
_outputs = outputs;
|
_outputs = outputs;
|
||||||
_floatArguments = floatArguments;
|
_floatArguments = floatArguments;
|
||||||
_intArguments = intArguments;
|
_intArguments = intArguments;
|
||||||
_operations.clear();
|
|
||||||
_memory.clear();
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const mkldnn::engine &getEngine() { return _engine; }
|
|
||||||
void setEngine(const mkldnn::engine &engine) { _engine = engine; }
|
|
||||||
|
|
||||||
const std::vector<mkldnn::memory> &getMemory() { return _memory; }
|
|
||||||
void setMemory(const std::vector<mkldnn::memory> &memory) { _memory = memory; }
|
|
||||||
void addMemory(const mkldnn::memory &memory) { _memory.push_back(memory); }
|
|
||||||
|
|
||||||
const std::vector<mkldnn::primitive> &getOperations() { return _operations; }
|
|
||||||
void setOperations(const std::vector<mkldnn::primitive> &operations) { _operations = operations; }
|
|
||||||
void addOperation(const mkldnn::primitive &operation) { _operations.push_back(operation); }
|
|
||||||
|
|
||||||
bool submitAndWait(mkldnn::stream::kind kind = mkldnn::stream::kind::eager) {
|
|
||||||
nd4j_debug("Executing %s with MKL-DNN\n", _opName.c_str());
|
|
||||||
// need to create a new one because already executed streams become unusable
|
|
||||||
mkldnn::stream stream(kind);
|
|
||||||
return stream.submit(_operations).wait();
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
#include <helpers/OpTracker.h>
|
#include <helpers/OpTracker.h>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <helpers/logger.h>
|
#include <helpers/logger.h>
|
||||||
|
#include <NativeOps.h>
|
||||||
|
|
||||||
|
|
||||||
using namespace nd4j::ops;
|
using namespace nd4j::ops;
|
||||||
using namespace nd4j::graph;
|
using namespace nd4j::graph;
|
||||||
|
@ -35,6 +37,31 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpTracker::storeOperation(nd4j::graph::OpType opType, const OpDescriptor& descriptor) {
|
void OpTracker::storeOperation(nd4j::graph::OpType opType, const OpDescriptor& descriptor) {
|
||||||
|
// check out CPU features
|
||||||
|
if (!::isMinimalRequirementsMet()) {
|
||||||
|
|
||||||
|
auto binaryLevel = ::binaryLevel();
|
||||||
|
auto optimalLevel = ::optimalLevel();
|
||||||
|
|
||||||
|
switch (binaryLevel) {
|
||||||
|
case 3: {
|
||||||
|
nd4j_printf("libnd4j binary was built with AVX512 support, but current CPU doesn't have this instruction set. Exiting now...","");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 2: {
|
||||||
|
nd4j_printf("libnd4j binary was built with AVX/AVX2 support, but current CPU doesn't have this instruction set. Exiting now...","");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default: {
|
||||||
|
nd4j_printf("Unknown binary validation error. Exiting now...","");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// we're exiting now
|
||||||
|
exit(119);
|
||||||
|
}
|
||||||
|
//
|
||||||
if (_map.count(opType) < 1) {
|
if (_map.count(opType) < 1) {
|
||||||
std::vector<OpDescriptor> vec;
|
std::vector<OpDescriptor> vec;
|
||||||
_map[opType] = vec;
|
_map[opType] = vec;
|
||||||
|
|
|
@ -4417,8 +4417,10 @@ INLINEDEF void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const c
|
||||||
if(order == shape::order(shapeInfo) || e == 1) { // e==1 means common vector
|
if(order == shape::order(shapeInfo) || e == 1) { // e==1 means common vector
|
||||||
e = 1;
|
e = 1;
|
||||||
Nd4jLong len = shape::length(shapeInfo);
|
Nd4jLong len = shape::length(shapeInfo);
|
||||||
while(e < len)
|
while(e < len) {
|
||||||
offsets[e++] = offsets[e - 1] + ews;
|
offsets[e] = offsets[e - 1] + ews;
|
||||||
|
e++;
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4464,8 +4466,10 @@ INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong
|
||||||
if(shape[j] == 1) { --j; continue; } // ignore dimensions equal to unity
|
if(shape[j] == 1) { --j; continue; } // ignore dimensions equal to unity
|
||||||
|
|
||||||
if(j == rankMinusOne) { // last dimension
|
if(j == rankMinusOne) { // last dimension
|
||||||
for(int l = 1; l < shape[j]; ++l)
|
for(int l = 1; l < shape[j]; ++l) {
|
||||||
offsets[i++] = offsets[i - 1] + strides[j];
|
offsets[i] = offsets[i - 1] + strides[j];
|
||||||
|
i++;
|
||||||
|
}
|
||||||
--j;
|
--j;
|
||||||
}
|
}
|
||||||
else if(idx[j] < shape[j] - 1) {
|
else if(idx[j] < shape[j] - 1) {
|
||||||
|
@ -4489,8 +4493,10 @@ INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong
|
||||||
if(shape[j] == 1) { ++j; continue; } // ignore dimensions equal to unity
|
if(shape[j] == 1) { ++j; continue; } // ignore dimensions equal to unity
|
||||||
|
|
||||||
if(j == 0) { // last dimension
|
if(j == 0) { // last dimension
|
||||||
for(int l = 1; l < shape[j]; ++l)
|
for(int l = 1; l < shape[j]; ++l) {
|
||||||
offsets[i++] = offsets[i - 1] + strides[j];
|
offsets[i] = offsets[i - 1] + strides[j];
|
||||||
|
i++;
|
||||||
|
}
|
||||||
++j;
|
++j;
|
||||||
}
|
}
|
||||||
else if(idx[j] < shape[j] - 1) {
|
else if(idx[j] < shape[j] - 1) {
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace nd4j {
|
||||||
OpDescriptor * _descriptor;
|
OpDescriptor * _descriptor;
|
||||||
|
|
||||||
bool prepareOutputs(Context& block);
|
bool prepareOutputs(Context& block);
|
||||||
virtual Nd4jStatus validateAndExecute(Context& block) = 0;
|
Nd4jStatus validateAndExecute(Context& block) override = 0;
|
||||||
public:
|
public:
|
||||||
BooleanOp(const char *name, int numInputs, bool scalar);
|
BooleanOp(const char *name, int numInputs, bool scalar);
|
||||||
~BooleanOp();
|
~BooleanOp();
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class ND4J_EXPORT BroadcastableOp : public DeclarableCustomOp{
|
class ND4J_EXPORT BroadcastableOp : public DeclarableCustomOp{
|
||||||
protected:
|
protected:
|
||||||
virtual Nd4jStatus validateAndExecute(Context& block) = 0;
|
Nd4jStatus validateAndExecute(Context& block) override = 0;
|
||||||
public:
|
public:
|
||||||
BroadcastableOp(const char *name, int numTArgs, int numIArgs);
|
BroadcastableOp(const char *name, int numTArgs, int numIArgs);
|
||||||
~BroadcastableOp();
|
~BroadcastableOp();
|
||||||
|
|
|
@ -30,12 +30,12 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* This method executes this Op
|
* This method executes this Op
|
||||||
*/
|
*/
|
||||||
virtual Nd4jStatus validateAndExecute(Context& block) = 0;
|
Nd4jStatus validateAndExecute(Context& block) override = 0;
|
||||||
public:
|
public:
|
||||||
DeclarableCustomOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs);
|
DeclarableCustomOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs);
|
||||||
~DeclarableCustomOp();
|
~DeclarableCustomOp();
|
||||||
|
|
||||||
virtual ShapeList* calculateOutputShape(ShapeList* inputShapes, nd4j::graph::Context& block) = 0;
|
ShapeList* calculateOutputShape(ShapeList* inputShapes, nd4j::graph::Context& block) override = 0;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class ND4J_EXPORT DeclarableListOp : public nd4j::ops::DeclarableOp {
|
class ND4J_EXPORT DeclarableListOp : public nd4j::ops::DeclarableOp {
|
||||||
protected:
|
protected:
|
||||||
virtual Nd4jStatus validateAndExecute(Context& block) = 0;
|
Nd4jStatus validateAndExecute(Context& block) override = 0;
|
||||||
|
|
||||||
nd4j::NDArray* getZ(Context& block, int inputId);
|
nd4j::NDArray* getZ(Context& block, int inputId);
|
||||||
void setupResult(NDArray* array, Context& block);
|
void setupResult(NDArray* array, Context& block);
|
||||||
|
|
|
@ -30,12 +30,12 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* This method executes this Op
|
* This method executes this Op
|
||||||
*/
|
*/
|
||||||
virtual Nd4jStatus validateAndExecute(Context& block) = 0;
|
Nd4jStatus validateAndExecute(Context& block) override = 0;
|
||||||
public:
|
public:
|
||||||
DeclarableReductionOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs);
|
DeclarableReductionOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs);
|
||||||
~DeclarableReductionOp();
|
~DeclarableReductionOp();
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,13 +30,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyBroadcastBoolOp : public LegacyOp {
|
class ND4J_EXPORT LegacyBroadcastBoolOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override ;
|
||||||
public:
|
public:
|
||||||
LegacyBroadcastBoolOp();
|
LegacyBroadcastBoolOp();
|
||||||
LegacyBroadcastBoolOp(int opNum);
|
LegacyBroadcastBoolOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,13 +30,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyBroadcastOp : public LegacyOp {
|
class ND4J_EXPORT LegacyBroadcastOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyBroadcastOp();
|
LegacyBroadcastOp();
|
||||||
LegacyBroadcastOp(int opNum);
|
LegacyBroadcastOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,13 +32,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyIndexReduceOp : public LegacyOp {
|
class ND4J_EXPORT LegacyIndexReduceOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyIndexReduceOp();
|
LegacyIndexReduceOp();
|
||||||
LegacyIndexReduceOp(int opNum);
|
LegacyIndexReduceOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,13 +41,13 @@ namespace nd4j {
|
||||||
int _numInputs = 0;
|
int _numInputs = 0;
|
||||||
|
|
||||||
// All Op classes provide own specific implementation for this method
|
// All Op classes provide own specific implementation for this method
|
||||||
virtual Nd4jStatus validateAndExecute(Context& block) = 0;
|
Nd4jStatus validateAndExecute(Context& block) override = 0;
|
||||||
public:
|
public:
|
||||||
LegacyOp(int numInputs);
|
LegacyOp(int numInputs);
|
||||||
LegacyOp(int numInputs, int opNum);
|
LegacyOp(int numInputs, int opNum);
|
||||||
|
|
||||||
// All Op classes provide own specific implementation for this method
|
// All Op classes provide own specific implementation for this method
|
||||||
virtual ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) = 0;
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override = 0;
|
||||||
virtual LegacyOp* clone() = 0;
|
virtual LegacyOp* clone() = 0;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,13 +30,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyPairwiseTransformBoolOp: public LegacyOp {
|
class ND4J_EXPORT LegacyPairwiseTransformBoolOp: public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyPairwiseTransformBoolOp();
|
LegacyPairwiseTransformBoolOp();
|
||||||
LegacyPairwiseTransformBoolOp(int opNum);
|
LegacyPairwiseTransformBoolOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,13 +30,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyPairwiseTransformOp: public LegacyOp {
|
class ND4J_EXPORT LegacyPairwiseTransformOp: public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyPairwiseTransformOp();
|
LegacyPairwiseTransformOp();
|
||||||
LegacyPairwiseTransformOp(int opNum);
|
LegacyPairwiseTransformOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyRandomOp : public LegacyOp {
|
class ND4J_EXPORT LegacyRandomOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyRandomOp();
|
LegacyRandomOp();
|
||||||
LegacyRandomOp(int opNum);
|
LegacyRandomOp(int opNum);
|
||||||
|
@ -43,10 +43,10 @@ namespace nd4j {
|
||||||
|
|
||||||
nd4j::ResultSet* execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
|
nd4j::ResultSet* execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
|
||||||
nd4j::ResultSet* execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);
|
nd4j::ResultSet* execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);
|
||||||
Nd4jStatus execute(Context* block);
|
Nd4jStatus execute(Context* block) override;
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,13 +30,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyReduce3Op : public LegacyOp {
|
class ND4J_EXPORT LegacyReduce3Op : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyReduce3Op();
|
LegacyReduce3Op();
|
||||||
LegacyReduce3Op(int opNum);
|
LegacyReduce3Op(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,13 +27,13 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class ND4J_EXPORT LegacyReduceBoolOp : public LegacyOp {
|
class ND4J_EXPORT LegacyReduceBoolOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyReduceBoolOp();
|
LegacyReduceBoolOp();
|
||||||
LegacyReduceBoolOp(int opNum);
|
LegacyReduceBoolOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,13 +27,13 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class ND4J_EXPORT LegacyReduceFloatOp : public LegacyOp {
|
class ND4J_EXPORT LegacyReduceFloatOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyReduceFloatOp();
|
LegacyReduceFloatOp();
|
||||||
LegacyReduceFloatOp(int opNum);
|
LegacyReduceFloatOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,13 +27,13 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class ND4J_EXPORT LegacyReduceLongOp : public LegacyOp {
|
class ND4J_EXPORT LegacyReduceLongOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyReduceLongOp();
|
LegacyReduceLongOp();
|
||||||
LegacyReduceLongOp(int opNum);
|
LegacyReduceLongOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,13 +27,13 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class ND4J_EXPORT LegacyReduceSameOp: public LegacyOp {
|
class ND4J_EXPORT LegacyReduceSameOp: public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
public:
|
public:
|
||||||
LegacyReduceSameOp();
|
LegacyReduceSameOp();
|
||||||
LegacyReduceSameOp(int opNum);
|
LegacyReduceSameOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,15 +30,15 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyScalarBoolOp : public LegacyOp {
|
class ND4J_EXPORT LegacyScalarBoolOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LegacyScalarBoolOp();
|
LegacyScalarBoolOp();
|
||||||
LegacyScalarBoolOp(int opNum);
|
LegacyScalarBoolOp(int opNum);
|
||||||
LegacyScalarBoolOp(int opNum, NDArray &scalar);
|
LegacyScalarBoolOp(int opNum, NDArray &scalar);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,15 +30,15 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyScalarOp : public LegacyOp {
|
class ND4J_EXPORT LegacyScalarOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context& block);
|
Nd4jStatus validateAndExecute(Context& block) override;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LegacyScalarOp();
|
LegacyScalarOp();
|
||||||
LegacyScalarOp(int opNum);
|
LegacyScalarOp(int opNum);
|
||||||
LegacyScalarOp(int opNum, NDArray &scalar);
|
LegacyScalarOp(int opNum, NDArray &scalar);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,13 +30,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyStatsOp : public LegacyOp {
|
class ND4J_EXPORT LegacyStatsOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context &block);
|
Nd4jStatus validateAndExecute(Context &block) override;
|
||||||
public:
|
public:
|
||||||
LegacyStatsOp();
|
LegacyStatsOp();
|
||||||
LegacyStatsOp(int opNum);
|
LegacyStatsOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,13 +31,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyTransformAnyOp : public LegacyOp {
|
class ND4J_EXPORT LegacyTransformAnyOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context &block);
|
Nd4jStatus validateAndExecute(Context &block) override;
|
||||||
public:
|
public:
|
||||||
LegacyTransformAnyOp();
|
LegacyTransformAnyOp();
|
||||||
LegacyTransformAnyOp(int opNum);
|
LegacyTransformAnyOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,13 +32,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyTransformBoolOp : public LegacyOp {
|
class ND4J_EXPORT LegacyTransformBoolOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context &block);
|
Nd4jStatus validateAndExecute(Context &block) override;
|
||||||
public:
|
public:
|
||||||
LegacyTransformBoolOp();
|
LegacyTransformBoolOp();
|
||||||
LegacyTransformBoolOp(int opNum);
|
LegacyTransformBoolOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,13 +31,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyTransformFloatOp : public LegacyOp {
|
class ND4J_EXPORT LegacyTransformFloatOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context &block);
|
Nd4jStatus validateAndExecute(Context &block) override;
|
||||||
public:
|
public:
|
||||||
LegacyTransformFloatOp();
|
LegacyTransformFloatOp();
|
||||||
LegacyTransformFloatOp(int opNum);
|
LegacyTransformFloatOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,13 +32,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyTransformSameOp : public LegacyOp {
|
class ND4J_EXPORT LegacyTransformSameOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context &block);
|
Nd4jStatus validateAndExecute(Context &block) override;
|
||||||
public:
|
public:
|
||||||
LegacyTransformSameOp();
|
LegacyTransformSameOp();
|
||||||
LegacyTransformSameOp(int opNum);
|
LegacyTransformSameOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,13 +32,13 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
class ND4J_EXPORT LegacyTransformStrictOp : public LegacyOp {
|
class ND4J_EXPORT LegacyTransformStrictOp : public LegacyOp {
|
||||||
protected:
|
protected:
|
||||||
Nd4jStatus validateAndExecute(Context &block);
|
Nd4jStatus validateAndExecute(Context &block) override;
|
||||||
public:
|
public:
|
||||||
LegacyTransformStrictOp();
|
LegacyTransformStrictOp();
|
||||||
LegacyTransformStrictOp(int opNum);
|
LegacyTransformStrictOp(int opNum);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
|
ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
|
||||||
virtual LegacyOp* clone();
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <ops/declarable/DeclarableOp.h>
|
#include <ops/declarable/DeclarableOp.h>
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
|
||||||
// handlers part
|
// handlers part
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
@ -59,10 +60,16 @@ namespace nd4j {
|
||||||
|
|
||||||
std::map<Nd4jLong, std::string> _msvc;
|
std::map<Nd4jLong, std::string> _msvc;
|
||||||
|
|
||||||
|
// pointers to our operations
|
||||||
std::map<Nd4jLong, nd4j::ops::DeclarableOp*> _declarablesLD;
|
std::map<Nd4jLong, nd4j::ops::DeclarableOp*> _declarablesLD;
|
||||||
std::map<std::string, nd4j::ops::DeclarableOp*> _declarablesD;
|
std::map<std::string, nd4j::ops::DeclarableOp*> _declarablesD;
|
||||||
std::vector<nd4j::ops::DeclarableOp *> _uniqueD;
|
std::vector<nd4j::ops::DeclarableOp *> _uniqueD;
|
||||||
|
|
||||||
|
// pointers to platform-specific helpers
|
||||||
|
std::map<Nd4jLong, nd4j::ops::platforms::PlatformHelper*> _helpersLH;
|
||||||
|
std::map<std::string, nd4j::ops::platforms::PlatformHelper*> _helpersH;
|
||||||
|
std::vector<nd4j::ops::platforms::PlatformHelper*> _uniqueH;
|
||||||
|
|
||||||
std::mutex _locker;
|
std::mutex _locker;
|
||||||
std::string _opsList;
|
std::string _opsList;
|
||||||
bool isInit = false;
|
bool isInit = false;
|
||||||
|
@ -82,17 +89,23 @@ namespace nd4j {
|
||||||
const char * getAllCustomOperations();
|
const char * getAllCustomOperations();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method registers operation
|
* This method registers operation in our registry, so we can use them later
|
||||||
*
|
*
|
||||||
* @param op
|
* @param op
|
||||||
*/
|
*/
|
||||||
bool registerOperation(const char* name, nd4j::ops::DeclarableOp* op);
|
bool registerOperation(const char* name, nd4j::ops::DeclarableOp* op);
|
||||||
bool registerOperation(nd4j::ops::DeclarableOp *op);
|
bool registerOperation(nd4j::ops::DeclarableOp *op);
|
||||||
|
|
||||||
|
void registerHelper(nd4j::ops::platforms::PlatformHelper* op);
|
||||||
|
|
||||||
|
bool hasHelper(Nd4jLong hash);
|
||||||
|
|
||||||
nd4j::ops::DeclarableOp* getOperation(const char *name);
|
nd4j::ops::DeclarableOp* getOperation(const char *name);
|
||||||
nd4j::ops::DeclarableOp* getOperation(Nd4jLong hash);
|
nd4j::ops::DeclarableOp* getOperation(Nd4jLong hash);
|
||||||
nd4j::ops::DeclarableOp* getOperation(std::string &name);
|
nd4j::ops::DeclarableOp* getOperation(std::string &name);
|
||||||
|
|
||||||
|
nd4j::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash);
|
||||||
|
|
||||||
std::vector<Nd4jLong> getAllHashes();
|
std::vector<Nd4jLong> getAllHashes();
|
||||||
|
|
||||||
int numberOfOperations();
|
int numberOfOperations();
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SD_PLATFORMHELPER_H
|
||||||
|
#define SD_PLATFORMHELPER_H
|
||||||
|
|
||||||
|
#include <ShapeUtils.h>
|
||||||
|
#include <graph/Context.h>
|
||||||
|
#include <string>
|
||||||
|
#include <pointercast.h>
|
||||||
|
#include <dll.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
/**
|
||||||
|
* This abstract class defines methods used by platform-specific helpers implementations
|
||||||
|
*/
|
||||||
|
class ND4J_EXPORT PlatformHelper {
|
||||||
|
protected:
|
||||||
|
// name of the operation this helper is built for
|
||||||
|
std::string _name;
|
||||||
|
|
||||||
|
// hash of the operation this helper is built for
|
||||||
|
Nd4jLong _hash;
|
||||||
|
public:
|
||||||
|
PlatformHelper(const char *name);
|
||||||
|
|
||||||
|
~PlatformHelper() = default;
|
||||||
|
|
||||||
|
std::string name();
|
||||||
|
|
||||||
|
Nd4jLong hash();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method checks, if given helper can be used with given input/output/configuration options
|
||||||
|
*
|
||||||
|
* @param context
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
virtual bool isUsable(graph::Context &context) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method invokes helper. Typically this method replaces actual op execution
|
||||||
|
*
|
||||||
|
* @param context
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
virtual Nd4jStatus invokeHelper(graph::Context &context) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper method, needed for compatibility with DeclarableOp macros
|
||||||
|
* @param ctx
|
||||||
|
* @param inputId
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
nd4j::NDArray *getZ(graph::Context &ctx, int inputId);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //SD_PLATFORMHELPER_H
|
|
@ -28,54 +28,6 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
using namespace mkldnn;
|
|
||||||
|
|
||||||
static void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
|
||||||
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
|
|
||||||
const Nd4jLong* shape = src->getShapeInfo();
|
|
||||||
Nd4jLong rank = shape[0];
|
|
||||||
Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
|
||||||
Nd4jLong dim2 = axis >= 2 ? 1 : 2;
|
|
||||||
Nd4jLong dim3 = axis >= 3 ? 2 : 3;
|
|
||||||
mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
|
||||||
|
|
||||||
auto type = mkldnn::memory::data_type::f32;
|
|
||||||
auto format = mkldnn::memory::format::nchw;
|
|
||||||
auto supposed_to_be_any_format = mkldnn::memory::format::nChw8c; // doesn't work with "any"
|
|
||||||
|
|
||||||
if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
|
|
||||||
*batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
|
||||||
user_src_md->data.format = mkldnn_blocked; // overrides format
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[0];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[dim1];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
|
|
||||||
*batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
|
||||||
user_diff_src_md->data.format = mkldnn_blocked; // overrides format
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[0];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[dim1];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
|
|
||||||
*batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
|
||||||
user_dst_md->data.format = mkldnn_blocked; // overrides format
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[0];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[dim1];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
|
CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
@ -208,84 +160,6 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
|
||||||
for(int i = 1; i < block.width(); ++i)
|
for(int i = 1; i < block.width(); ++i)
|
||||||
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_NEW op: types of all input arrays should be the same !");
|
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_NEW op: types of all input arrays should be the same !");
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) && numOfAxes == 1) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("batchnorm_new"));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
|
|
||||||
NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
|
|
||||||
weights({0, 1, 0, 0}).assign(1.0f);
|
|
||||||
weights({1, 2, 0, 0}).assign(0.0f);
|
|
||||||
|
|
||||||
if (streams[0].checkAndReset({input, mean, variance, gamma, beta}, {output}, {(float)epsilon}, axes)) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
|
|
||||||
|
|
||||||
getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
|
|
||||||
&batchnorm_src_md, nullptr, &batchnorm_dst_md,
|
|
||||||
&user_src_md, nullptr, &user_dst_md, axes[0]);
|
|
||||||
|
|
||||||
auto batchnorm_desc = batch_normalization_forward::desc(prop_kind::forward_inference, batchnorm_src_md, epsilon,
|
|
||||||
use_global_stats | (applyScale || applyOffset ? use_scale_shift : 0));
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto batchnorm_prim_desc = batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
|
|
||||||
auto user_src_memory = mkldnn::memory({user_src_md, engine}, input->buffer());
|
|
||||||
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output->buffer());
|
|
||||||
auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_primitive_desc(), mean->buffer());
|
|
||||||
auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_primitive_desc(), variance->buffer());
|
|
||||||
|
|
||||||
auto batchnorm_src_memory = user_src_memory;
|
|
||||||
streams[0].addMemory(user_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc({batchnorm_src_md, engine})
|
|
||||||
!= user_src_memory.get_primitive_desc()) {
|
|
||||||
batchnorm_src_memory = mkldnn::memory({batchnorm_src_md, engine});
|
|
||||||
streams[0].addMemory(batchnorm_src_memory);
|
|
||||||
streams[0].addOperation(reorder(user_src_memory, batchnorm_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto batchnorm_dst_memory = user_dst_memory;
|
|
||||||
streams[0].addMemory(user_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(batchnorm_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_primitive_desc());
|
|
||||||
streams[0].addMemory(batchnorm_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].addMemory(batchnorm_mean_memory);
|
|
||||||
streams[0].addMemory(batchnorm_variance_memory);
|
|
||||||
|
|
||||||
if (applyScale || applyOffset) {
|
|
||||||
auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_primitive_desc(), weights.buffer());
|
|
||||||
streams[0].addMemory(batchnorm_weights_memory);
|
|
||||||
streams[0].addOperation(batch_normalization_forward(batchnorm_prim_desc, (mkldnn::primitive::at)batchnorm_src_memory,
|
|
||||||
(mkldnn::primitive::at)batchnorm_mean_memory, (mkldnn::primitive::at)batchnorm_variance_memory, (mkldnn::primitive::at)batchnorm_weights_memory, batchnorm_dst_memory));
|
|
||||||
} else {
|
|
||||||
streams[0].addOperation(batch_normalization_forward(batchnorm_prim_desc, (mkldnn::primitive::at)batchnorm_src_memory,
|
|
||||||
(mkldnn::primitive::at)batchnorm_mean_memory, (mkldnn::primitive::at)batchnorm_variance_memory, batchnorm_dst_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(batchnorm_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(batchnorm_dst_memory, user_dst_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (applyScale || applyOffset) {
|
|
||||||
if (gamma != nullptr) {
|
|
||||||
weights({0, 1, 0, 0}).assign(gamma);
|
|
||||||
}
|
|
||||||
if (beta != nullptr) {
|
|
||||||
weights({1, 2, 0, 0}).assign(beta);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for batchnorm_new!\n", 0);
|
nd4j_debug("MKL-DNN is not used for batchnorm_new!\n", 0);
|
||||||
|
|
||||||
// formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
|
// formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
|
||||||
|
|
|
@ -29,12 +29,7 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
using namespace mkldnn;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
|
CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
@ -70,83 +65,6 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output})) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("conv3dnew"));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (streams[0].checkAndReset({input, weights, bias}, {output}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNCDHW})) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
|
||||||
|
|
||||||
ConvolutionUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNCDHW,
|
|
||||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights, nullptr, bias, output,
|
|
||||||
&conv_src_md, nullptr, &conv_weights_md, nullptr, &conv_bias_md, &conv_dst_md,
|
|
||||||
&user_src_md, nullptr, &user_weights_md, nullptr, &user_bias_md, &user_dst_md,
|
|
||||||
conv_strides, conv_padding, conv_padding_r);
|
|
||||||
|
|
||||||
auto conv_desc = bias != nullptr
|
|
||||||
? convolution_forward::desc(prop_kind::forward,
|
|
||||||
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
|
|
||||||
: convolution_forward::desc(prop_kind::forward,
|
|
||||||
convolution_direct, conv_src_md, conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
|
|
||||||
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
|
|
||||||
auto user_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
|
|
||||||
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output->buffer());
|
|
||||||
|
|
||||||
auto conv_src_memory = user_src_memory;
|
|
||||||
streams[0].addMemory(user_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(conv_prim_desc.src_primitive_desc())
|
|
||||||
!= user_src_memory.get_primitive_desc()) {
|
|
||||||
conv_src_memory = mkldnn::memory(conv_prim_desc.src_primitive_desc());
|
|
||||||
streams[0].addMemory(conv_src_memory);
|
|
||||||
streams[0].addOperation(reorder(user_src_memory, conv_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto conv_weights_memory = user_weights_memory;
|
|
||||||
streams[0].addMemory(user_weights_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(conv_prim_desc.weights_primitive_desc())
|
|
||||||
!= user_weights_memory.get_primitive_desc()) {
|
|
||||||
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_primitive_desc());
|
|
||||||
streams[0].addMemory(conv_weights_memory);
|
|
||||||
streams[0].addOperation(reorder(user_weights_memory, conv_weights_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto conv_dst_memory = user_dst_memory;
|
|
||||||
streams[0].addMemory(user_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_primitive_desc());
|
|
||||||
streams[0].addMemory(conv_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bias != nullptr) {
|
|
||||||
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_primitive_desc(), bias->buffer());
|
|
||||||
streams[0].addMemory(conv_bias_memory);
|
|
||||||
streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_bias_memory, conv_dst_memory));
|
|
||||||
} else {
|
|
||||||
streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_dst_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(conv_dst_memory, user_dst_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0);
|
nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0);
|
||||||
|
|
||||||
std::vector<int> permutForOutput;
|
std::vector<int> permutForOutput;
|
||||||
|
@ -297,151 +215,6 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB})) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("conv3dnew_bp_weights"));
|
|
||||||
streams.push_back(MKLDNNStream("conv3dnew_bp_data"));
|
|
||||||
}
|
|
||||||
|
|
||||||
bool resetW = streams[0].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNDHWC});
|
|
||||||
bool resetI = streams[1].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNDHWC});
|
|
||||||
if (resetW || resetI) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
|
||||||
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
|
||||||
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
|
||||||
|
|
||||||
ConvolutionUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNDHWC,
|
|
||||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights, gradW, gradB, gradO,
|
|
||||||
&conv_src_md, &conv_diff_src_md, &conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
|
||||||
&user_src_md, &user_diff_src_md, &user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md,
|
|
||||||
conv_strides, conv_padding, conv_padding_r);
|
|
||||||
|
|
||||||
auto conv_desc = gradB != nullptr
|
|
||||||
? convolution_forward::desc(prop_kind::forward,
|
|
||||||
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
|
|
||||||
: convolution_forward::desc(prop_kind::forward,
|
|
||||||
convolution_direct, conv_src_md, conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, streams[0].getEngine());
|
|
||||||
|
|
||||||
if (gradW != nullptr) {
|
|
||||||
auto convW_desc = gradB != nullptr
|
|
||||||
? convolution_backward_weights::desc(
|
|
||||||
convolution_direct, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
|
|
||||||
: convolution_backward_weights::desc(
|
|
||||||
convolution_direct, conv_src_md, conv_diff_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
|
|
||||||
auto userW_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
|
|
||||||
auto userW_weights_memory = mkldnn::memory({user_diff_weights_md, engine}, gradW->buffer());
|
|
||||||
auto userW_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());
|
|
||||||
|
|
||||||
auto convW_src_memory = userW_src_memory;
|
|
||||||
streams[0].addMemory(userW_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convW_prim_desc.src_primitive_desc())
|
|
||||||
!= userW_src_memory.get_primitive_desc()) {
|
|
||||||
convW_src_memory = mkldnn::memory(convW_prim_desc.src_primitive_desc());
|
|
||||||
streams[0].addMemory(convW_src_memory);
|
|
||||||
streams[0].addOperation(reorder(userW_src_memory, convW_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convW_weights_memory = userW_weights_memory;
|
|
||||||
streams[0].addMemory(userW_weights_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
|
|
||||||
!= userW_weights_memory.get_primitive_desc()) {
|
|
||||||
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_primitive_desc());
|
|
||||||
streams[0].addMemory(convW_weights_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convW_dst_memory = userW_dst_memory;
|
|
||||||
streams[0].addMemory(userW_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_dst_primitive_desc())
|
|
||||||
!= userW_dst_memory.get_primitive_desc()) {
|
|
||||||
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_primitive_desc());
|
|
||||||
streams[0].addMemory(convW_dst_memory);
|
|
||||||
streams[0].addOperation(reorder(userW_dst_memory, convW_dst_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gradB != nullptr) {
|
|
||||||
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_primitive_desc(), gradB->buffer());
|
|
||||||
streams[0].addMemory(convW_bias_memory);
|
|
||||||
streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory, convW_bias_memory));
|
|
||||||
} else {
|
|
||||||
streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
|
|
||||||
!= userW_weights_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(convW_weights_memory, userW_weights_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gradI != nullptr) {
|
|
||||||
auto convI_desc =
|
|
||||||
convolution_backward_data::desc(
|
|
||||||
convolution_direct, conv_diff_src_md, conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[1].getEngine();
|
|
||||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
|
|
||||||
auto userI_src_memory = mkldnn::memory({user_diff_src_md, engine}, gradI->buffer());
|
|
||||||
auto userI_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
|
|
||||||
auto userI_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());
|
|
||||||
|
|
||||||
auto convI_src_memory = userI_src_memory;
|
|
||||||
streams[1].addMemory(userI_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
|
|
||||||
!= userI_src_memory.get_primitive_desc()) {
|
|
||||||
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_primitive_desc());
|
|
||||||
streams[1].addMemory(convI_src_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convI_weights_memory = userI_weights_memory;
|
|
||||||
streams[1].addMemory(userI_weights_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convI_prim_desc.weights_primitive_desc())
|
|
||||||
!= userI_weights_memory.get_primitive_desc()) {
|
|
||||||
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_primitive_desc());
|
|
||||||
streams[1].addMemory(convI_weights_memory);
|
|
||||||
streams[1].addOperation(reorder(userI_weights_memory, convI_weights_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convI_dst_memory = userI_dst_memory;
|
|
||||||
streams[1].addMemory(userI_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_dst_primitive_desc())
|
|
||||||
!= userI_dst_memory.get_primitive_desc()) {
|
|
||||||
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_primitive_desc());
|
|
||||||
streams[1].addMemory(convI_dst_memory);
|
|
||||||
streams[1].addOperation(reorder(userI_dst_memory, convI_dst_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[1].addOperation(convolution_backward_data(convI_prim_desc, convI_dst_memory, convI_weights_memory, convI_src_memory));
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
|
|
||||||
!= userI_src_memory.get_primitive_desc()) {
|
|
||||||
streams[1].addOperation(reorder(convI_src_memory, userI_src_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gradW != nullptr) {
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
}
|
|
||||||
if (gradI != nullptr) {
|
|
||||||
streams[1].submitAndWait();
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0);
|
nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0);
|
||||||
|
|
||||||
std::vector<int> gradOaxesForDot;
|
std::vector<int> gradOaxesForDot;
|
|
@ -41,7 +41,6 @@ namespace nd4j {
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead", input->rankOf());
|
||||||
|
|
||||||
// FIXME: double?
|
|
||||||
double alpha = T_ARG(1);
|
double alpha = T_ARG(1);
|
||||||
double beta = T_ARG(2);
|
double beta = T_ARG(2);
|
||||||
double bias = T_ARG(0);
|
double bias = T_ARG(0);
|
||||||
|
|
|
@ -24,9 +24,6 @@
|
||||||
#include <NDArray.h>
|
#include <NDArray.h>
|
||||||
#include <graph/Context.h>
|
#include <graph/Context.h>
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
#include <helpers/MKLDNNStream.h>
|
|
||||||
#endif
|
|
||||||
#include <execution/LaunchContext.h>
|
#include <execution/LaunchContext.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
@ -197,44 +194,6 @@ namespace nd4j {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
static void getMKLDNNMemoryDescConv2d(
|
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
|
|
||||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
|
||||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
|
||||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
|
||||||
|
|
||||||
static void getMKLDNNMemoryDescConv3d(
|
|
||||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
|
||||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
|
||||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
|
||||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
|
||||||
|
|
||||||
static void getMKLDNNMemoryDescPool2d(
|
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
|
||||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
|
||||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
|
||||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
|
||||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
|
||||||
|
|
||||||
static void getMKLDNNMemoryDescPool3d(
|
|
||||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
|
||||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
|
||||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
|
||||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
|
||||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||||
|
|
||||||
// static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
|
// static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
|
||||||
|
|
|
@ -28,121 +28,6 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
using namespace mkldnn;
|
|
||||||
|
|
||||||
void ConvolutionUtils::getMKLDNNMemoryDescPool2d(
|
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
|
||||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
|
||||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
|
||||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
|
||||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
|
||||||
mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW };
|
|
||||||
mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW };
|
|
||||||
|
|
||||||
pool_strides = { sH, sW };
|
|
||||||
pool_kernel = { kH, kW };
|
|
||||||
pool_padding = { pH, pW };
|
|
||||||
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
|
||||||
(oW - 1) * sW - iW + kW - pW };
|
|
||||||
|
|
||||||
algorithm = poolingMode == 0 ? pooling_max
|
|
||||||
: extraParam0 == 0 ? pooling_avg_exclude_padding
|
|
||||||
: pooling_avg_include_padding;
|
|
||||||
auto type = mkldnn::memory::data_type::f32;
|
|
||||||
auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
|
|
||||||
auto supposed_to_be_any_format = mkldnn::memory::format::nChw8c; // doesn't work with "any"
|
|
||||||
|
|
||||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
|
||||||
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
|
||||||
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
|
||||||
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
|
||||||
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
|
||||||
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
|
||||||
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ConvolutionUtils::getMKLDNNMemoryDescPool3d(
|
|
||||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
|
||||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
|
||||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
|
||||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
|
||||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
|
||||||
mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
|
|
||||||
mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
|
|
||||||
|
|
||||||
pool_strides = { sD, sH, sW };
|
|
||||||
pool_kernel = { kD, kH, kW };
|
|
||||||
pool_padding = { pD, pH, pW };
|
|
||||||
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
|
||||||
(oH - 1) * sH - iH + kH - pH,
|
|
||||||
(oW - 1) * sW - iW + kW - pW };
|
|
||||||
|
|
||||||
algorithm = poolingMode == 0 ? pooling_max
|
|
||||||
: extraParam0 == 0 ? pooling_avg_exclude_padding
|
|
||||||
: pooling_avg_include_padding;
|
|
||||||
auto type = mkldnn::memory::data_type::f32;
|
|
||||||
auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc;
|
|
||||||
auto supposed_to_be_any_format = mkldnn::memory::format::nCdhw8c; // doesn't work with "any"
|
|
||||||
|
|
||||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
|
||||||
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
|
||||||
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
|
||||||
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
|
||||||
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
|
||||||
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
|
||||||
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
// [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||||
|
@ -348,174 +233,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescPool3d(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
using namespace mkldnn;
|
|
||||||
|
|
||||||
void ConvolutionUtils::getMKLDNNMemoryDescConv2d(
|
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
|
|
||||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
|
||||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
|
||||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
|
|
||||||
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
|
|
||||||
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
|
|
||||||
mkldnn::memory::dims conv_bias_tz = { oC };
|
|
||||||
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
|
||||||
|
|
||||||
conv_strides = { sH, sW };
|
|
||||||
conv_padding = { pH, pW };
|
|
||||||
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
|
||||||
(oW - 1) * sW - iW + kW - pW };
|
|
||||||
|
|
||||||
auto type = mkldnn::memory::data_type::f32;
|
|
||||||
auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
|
|
||||||
auto formatw = mkldnn::memory::format::hwio;
|
|
||||||
|
|
||||||
if (src != nullptr && conv_src_md != nullptr) {
|
|
||||||
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
|
||||||
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
|
||||||
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
|
||||||
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (weights != nullptr && conv_weights_md != nullptr) {
|
|
||||||
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
|
||||||
user_weights_md->data.format = mkldnn_blocked; // overrides "formatw = hwio"
|
|
||||||
user_weights_md->data.layout_desc.blocking.strides[0][0] = weights->stridesOf()[3];
|
|
||||||
user_weights_md->data.layout_desc.blocking.strides[0][1] = weights->stridesOf()[2];
|
|
||||||
user_weights_md->data.layout_desc.blocking.strides[0][2] = weights->stridesOf()[0];
|
|
||||||
user_weights_md->data.layout_desc.blocking.strides[0][3] = weights->stridesOf()[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
|
||||||
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
|
||||||
user_diff_weights_md->data.format = mkldnn_blocked; // overrides "formatw = hwio"
|
|
||||||
user_diff_weights_md->data.layout_desc.blocking.strides[0][0] = diff_weights->stridesOf()[3];
|
|
||||||
user_diff_weights_md->data.layout_desc.blocking.strides[0][1] = diff_weights->stridesOf()[2];
|
|
||||||
user_diff_weights_md->data.layout_desc.blocking.strides[0][2] = diff_weights->stridesOf()[0];
|
|
||||||
user_diff_weights_md->data.layout_desc.blocking.strides[0][3] = diff_weights->stridesOf()[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bias != nullptr && conv_bias_md != nullptr) {
|
|
||||||
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::x);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dst != nullptr && conv_dst_md != nullptr) {
|
|
||||||
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
|
|
||||||
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
|
||||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
|
||||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
|
||||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
|
||||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
|
|
||||||
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
|
|
||||||
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
|
|
||||||
mkldnn::memory::dims conv_bias_tz = { oC };
|
|
||||||
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
|
||||||
|
|
||||||
conv_strides = { sD, sH, sW };
|
|
||||||
conv_padding = { pD, pH, pW };
|
|
||||||
conv_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
|
||||||
(oH - 1) * sH - iH + kH - pH,
|
|
||||||
(oW - 1) * sW - iW + kW - pW };
|
|
||||||
|
|
||||||
auto type = mkldnn::memory::data_type::f32;
|
|
||||||
auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc;
|
|
||||||
auto formatw = mkldnn::memory::format::dhwio;
|
|
||||||
|
|
||||||
if (src != nullptr && conv_src_md != nullptr) {
|
|
||||||
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
|
||||||
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
|
||||||
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
|
||||||
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (weights != nullptr && conv_weights_md != nullptr) {
|
|
||||||
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
|
||||||
user_weights_md->data.format = mkldnn_blocked; // overrides "formatw = dhwio"
|
|
||||||
user_weights_md->data.layout_desc.blocking.strides[0][0] = weights->stridesOf()[4];
|
|
||||||
user_weights_md->data.layout_desc.blocking.strides[0][1] = weights->stridesOf()[3];
|
|
||||||
user_weights_md->data.layout_desc.blocking.strides[0][2] = weights->stridesOf()[0];
|
|
||||||
user_weights_md->data.layout_desc.blocking.strides[0][3] = weights->stridesOf()[1];
|
|
||||||
user_weights_md->data.layout_desc.blocking.strides[0][4] = weights->stridesOf()[2];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
|
||||||
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
|
||||||
user_diff_weights_md->data.format = mkldnn_blocked; // overrides "formatw = dhwio"
|
|
||||||
user_diff_weights_md->data.layout_desc.blocking.strides[0][0] = diff_weights->stridesOf()[4];
|
|
||||||
user_diff_weights_md->data.layout_desc.blocking.strides[0][1] = diff_weights->stridesOf()[3];
|
|
||||||
user_diff_weights_md->data.layout_desc.blocking.strides[0][2] = diff_weights->stridesOf()[0];
|
|
||||||
user_diff_weights_md->data.layout_desc.blocking.strides[0][3] = diff_weights->stridesOf()[1];
|
|
||||||
user_diff_weights_md->data.layout_desc.blocking.strides[0][4] = diff_weights->stridesOf()[2];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bias != nullptr && conv_bias_md != nullptr) {
|
|
||||||
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::x);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dst != nullptr && conv_dst_md != nullptr) {
|
|
||||||
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format::any);
|
|
||||||
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
|
|
||||||
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||||
|
@ -543,83 +260,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<X, Y>()) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("conv2d"));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (streams[0].checkAndReset({input, weights, bias}, {output}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW})) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
|
||||||
|
|
||||||
ConvolutionUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
|
||||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bias, output,
|
|
||||||
&conv_src_md, nullptr, &conv_weights_md, nullptr, &conv_bias_md, &conv_dst_md,
|
|
||||||
&user_src_md, nullptr, &user_weights_md, nullptr, &user_bias_md, &user_dst_md,
|
|
||||||
conv_strides, conv_padding, conv_padding_r);
|
|
||||||
|
|
||||||
auto conv_desc = bias != nullptr
|
|
||||||
? convolution_forward::desc(prop_kind::forward,
|
|
||||||
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
|
|
||||||
: convolution_forward::desc(prop_kind::forward,
|
|
||||||
convolution_direct, conv_src_md, conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
|
|
||||||
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
|
|
||||||
auto user_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
|
|
||||||
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output->buffer());
|
|
||||||
|
|
||||||
auto conv_src_memory = user_src_memory;
|
|
||||||
streams[0].addMemory(user_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(conv_prim_desc.src_primitive_desc())
|
|
||||||
!= user_src_memory.get_primitive_desc()) {
|
|
||||||
conv_src_memory = mkldnn::memory(conv_prim_desc.src_primitive_desc());
|
|
||||||
streams[0].addMemory(conv_src_memory);
|
|
||||||
streams[0].addOperation(reorder(user_src_memory, conv_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto conv_weights_memory = user_weights_memory;
|
|
||||||
streams[0].addMemory(user_weights_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(conv_prim_desc.weights_primitive_desc())
|
|
||||||
!= user_weights_memory.get_primitive_desc()) {
|
|
||||||
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_primitive_desc());
|
|
||||||
streams[0].addMemory(conv_weights_memory);
|
|
||||||
streams[0].addOperation(reorder(user_weights_memory, conv_weights_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto conv_dst_memory = user_dst_memory;
|
|
||||||
streams[0].addMemory(user_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_primitive_desc());
|
|
||||||
streams[0].addMemory(conv_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bias != nullptr) {
|
|
||||||
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_primitive_desc(), const_cast<NDArray*>(bias)->buffer());
|
|
||||||
streams[0].addMemory(conv_bias_memory);
|
|
||||||
streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_bias_memory, conv_dst_memory));
|
|
||||||
} else {
|
|
||||||
streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_dst_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(conv_dst_memory, user_dst_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for conv2d!\n", 0);
|
nd4j_debug("MKL-DNN is not used for conv2d!\n", 0);
|
||||||
|
|
||||||
std::vector<int> permutForOutput;
|
std::vector<int> permutForOutput;
|
||||||
|
@ -686,151 +326,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<X, Y>()) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("conv2d_bp_weights"));
|
|
||||||
streams.push_back(MKLDNNStream("conv2d_bp_data"));
|
|
||||||
}
|
|
||||||
|
|
||||||
bool resetW = streams[0].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW});
|
|
||||||
bool resetI = streams[1].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW});
|
|
||||||
if (resetW || resetI) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
|
||||||
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
|
||||||
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
|
||||||
|
|
||||||
ConvolutionUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
|
||||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, gradB, gradO,
|
|
||||||
&conv_src_md, &conv_diff_src_md, &conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
|
||||||
&user_src_md, &user_diff_src_md, &user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md,
|
|
||||||
conv_strides, conv_padding, conv_padding_r);
|
|
||||||
|
|
||||||
auto conv_desc = gradB != nullptr
|
|
||||||
? convolution_forward::desc(prop_kind::forward,
|
|
||||||
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
|
|
||||||
: convolution_forward::desc(prop_kind::forward,
|
|
||||||
convolution_direct, conv_src_md, conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, streams[0].getEngine());
|
|
||||||
|
|
||||||
if (gradW != nullptr) {
|
|
||||||
auto convW_desc = gradB != nullptr
|
|
||||||
? convolution_backward_weights::desc(
|
|
||||||
convolution_direct, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
|
|
||||||
: convolution_backward_weights::desc(
|
|
||||||
convolution_direct, conv_src_md, conv_diff_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
|
|
||||||
auto userW_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
|
|
||||||
auto userW_weights_memory = mkldnn::memory({user_diff_weights_md, engine}, gradW->buffer());
|
|
||||||
auto userW_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());
|
|
||||||
|
|
||||||
auto convW_src_memory = userW_src_memory;
|
|
||||||
streams[0].addMemory(userW_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convW_prim_desc.src_primitive_desc())
|
|
||||||
!= userW_src_memory.get_primitive_desc()) {
|
|
||||||
convW_src_memory = mkldnn::memory(convW_prim_desc.src_primitive_desc());
|
|
||||||
streams[0].addMemory(convW_src_memory);
|
|
||||||
streams[0].addOperation(reorder(userW_src_memory, convW_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convW_weights_memory = userW_weights_memory;
|
|
||||||
streams[0].addMemory(userW_weights_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
|
|
||||||
!= userW_weights_memory.get_primitive_desc()) {
|
|
||||||
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_primitive_desc());
|
|
||||||
streams[0].addMemory(convW_weights_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convW_dst_memory = userW_dst_memory;
|
|
||||||
streams[0].addMemory(userW_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_dst_primitive_desc())
|
|
||||||
!= userW_dst_memory.get_primitive_desc()) {
|
|
||||||
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_primitive_desc());
|
|
||||||
streams[0].addMemory(convW_dst_memory);
|
|
||||||
streams[0].addOperation(reorder(userW_dst_memory, convW_dst_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gradB != nullptr) {
|
|
||||||
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_primitive_desc(), gradB->buffer());
|
|
||||||
streams[0].addMemory(convW_bias_memory);
|
|
||||||
streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory, convW_bias_memory));
|
|
||||||
} else {
|
|
||||||
streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
|
|
||||||
!= userW_weights_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(convW_weights_memory, userW_weights_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gradI != nullptr) {
|
|
||||||
auto convI_desc =
|
|
||||||
convolution_backward_data::desc(
|
|
||||||
convolution_direct, conv_diff_src_md, conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[1].getEngine();
|
|
||||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
|
|
||||||
auto userI_src_memory = mkldnn::memory({user_diff_src_md, engine}, gradI->buffer());
|
|
||||||
auto userI_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
|
|
||||||
auto userI_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());
|
|
||||||
|
|
||||||
auto convI_src_memory = userI_src_memory;
|
|
||||||
streams[1].addMemory(userI_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
|
|
||||||
!= userI_src_memory.get_primitive_desc()) {
|
|
||||||
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_primitive_desc());
|
|
||||||
streams[1].addMemory(convI_src_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convI_weights_memory = userI_weights_memory;
|
|
||||||
streams[1].addMemory(userI_weights_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convI_prim_desc.weights_primitive_desc())
|
|
||||||
!= userI_weights_memory.get_primitive_desc()) {
|
|
||||||
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_primitive_desc());
|
|
||||||
streams[1].addMemory(convI_weights_memory);
|
|
||||||
streams[1].addOperation(reorder(userI_weights_memory, convI_weights_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convI_dst_memory = userI_dst_memory;
|
|
||||||
streams[1].addMemory(userI_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_dst_primitive_desc())
|
|
||||||
!= userI_dst_memory.get_primitive_desc()) {
|
|
||||||
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_primitive_desc());
|
|
||||||
streams[1].addMemory(convI_dst_memory);
|
|
||||||
streams[1].addOperation(reorder(userI_dst_memory, convI_dst_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[1].addOperation(convolution_backward_data(convI_prim_desc, convI_dst_memory, convI_weights_memory, convI_src_memory));
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
|
|
||||||
!= userI_src_memory.get_primitive_desc()) {
|
|
||||||
streams[1].addOperation(reorder(convI_src_memory, userI_src_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gradW != nullptr) {
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
}
|
|
||||||
if (gradI != nullptr) {
|
|
||||||
streams[1].submitAndWait();
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0);
|
nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0);
|
||||||
|
|
||||||
std::vector<int> gradOaxesForDot;
|
std::vector<int> gradOaxesForDot;
|
||||||
|
@ -1268,62 +763,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
||||||
const int oH = output.sizeAt(2);
|
const int oH = output.sizeAt(2);
|
||||||
const int oW = output.sizeAt(3);
|
const int oW = output.sizeAt(3);
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("pooling2d"));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (streams[0].checkAndReset({&input}, {&output}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0})) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
|
||||||
mkldnn::algorithm algorithm;
|
|
||||||
|
|
||||||
ConvolutionUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
|
|
||||||
bS, iC, iH, iW, oC, oH, oW, &input, nullptr, &output, algorithm,
|
|
||||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, &user_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
|
||||||
|
|
||||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, pool_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
|
||||||
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
|
|
||||||
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output.buffer());
|
|
||||||
|
|
||||||
auto pool_src_memory = user_src_memory;
|
|
||||||
streams[0].addMemory(user_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
|
|
||||||
!= user_src_memory.get_primitive_desc()) {
|
|
||||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_src_memory);
|
|
||||||
streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto pool_dst_memory = user_dst_memory;
|
|
||||||
streams[0].addMemory(user_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory));
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(pool_dst_memory, user_dst_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0);
|
nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0);
|
||||||
|
|
||||||
const Nd4jLong iStride0 = input.stridesOf()[0];
|
const Nd4jLong iStride0 = input.stridesOf()[0];
|
||||||
|
@ -1504,62 +943,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
||||||
const int oH = output.sizeAt(3);
|
const int oH = output.sizeAt(3);
|
||||||
const int oW = output.sizeAt(4);
|
const int oW = output.sizeAt(4);
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("pooling3d"));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (streams[0].checkAndReset({&input}, {&output}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0})) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
|
||||||
mkldnn::algorithm algorithm;
|
|
||||||
|
|
||||||
ConvolutionUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true,
|
|
||||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, &input, nullptr, &output, algorithm,
|
|
||||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, &user_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
|
||||||
|
|
||||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, pool_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
|
||||||
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
|
|
||||||
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output.buffer());
|
|
||||||
|
|
||||||
auto pool_src_memory = user_src_memory;
|
|
||||||
streams[0].addMemory(user_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
|
|
||||||
!= user_src_memory.get_primitive_desc()) {
|
|
||||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_src_memory);
|
|
||||||
streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto pool_dst_memory = user_dst_memory;
|
|
||||||
streams[0].addMemory(user_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory));
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(pool_dst_memory, user_dst_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0);
|
nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0);
|
||||||
|
|
||||||
const Nd4jLong iStride0 = input.stridesOf()[0];
|
const Nd4jLong iStride0 = input.stridesOf()[0];
|
||||||
|
@ -1776,91 +1159,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
||||||
const int oH = gradO.sizeAt(2);
|
const int oH = gradO.sizeAt(2);
|
||||||
const int oW = gradO.sizeAt(3);
|
const int oW = gradO.sizeAt(3);
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("pooling2d_bp"));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (streams[0].checkAndReset({&input, &gradO}, {&gradI}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0})) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
|
||||||
mkldnn::algorithm algorithm;
|
|
||||||
|
|
||||||
ConvolutionUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
|
|
||||||
bS, iC, iH, iW, oC, oH, oW, &input, &gradI, &gradO, algorithm,
|
|
||||||
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
|
||||||
|
|
||||||
// input is sometimes null, so we can't rely on pool_src_md being valid
|
|
||||||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
|
|
||||||
const_cast<NDArray&>(input).buffer() != nullptr ? pool_src_md : pool_diff_src_md,
|
|
||||||
pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
|
||||||
|
|
||||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
|
||||||
auto userB_src_memory = mkldnn::memory({user_src_md, engine}, gradI.buffer());
|
|
||||||
auto userB_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray&>(gradO).buffer());
|
|
||||||
|
|
||||||
auto poolB_src_memory = userB_src_memory;
|
|
||||||
streams[0].addMemory(userB_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
|
|
||||||
!= userB_src_memory.get_primitive_desc()) {
|
|
||||||
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_primitive_desc());
|
|
||||||
streams[0].addMemory(poolB_src_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto poolB_dst_memory = userB_dst_memory;
|
|
||||||
streams[0].addMemory(userB_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_dst_primitive_desc())
|
|
||||||
!= userB_dst_memory.get_primitive_desc()) {
|
|
||||||
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_primitive_desc());
|
|
||||||
streams[0].addMemory(poolB_dst_memory);
|
|
||||||
streams[0].addOperation(reorder(userB_dst_memory, poolB_dst_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (algorithm == mkldnn::pooling_max) {
|
|
||||||
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
|
|
||||||
|
|
||||||
auto pool_src_memory = user_src_memory;
|
|
||||||
streams[0].addMemory(user_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
|
|
||||||
!= user_src_memory.get_primitive_desc()) {
|
|
||||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_src_memory);
|
|
||||||
streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_dst_memory);
|
|
||||||
|
|
||||||
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_workspace_memory);
|
|
||||||
|
|
||||||
streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory, pool_workspace_memory));
|
|
||||||
streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, pool_workspace_memory, poolB_src_memory));
|
|
||||||
} else {
|
|
||||||
streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, poolB_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
|
|
||||||
!= userB_src_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(poolB_src_memory, userB_src_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0);
|
nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0);
|
||||||
|
|
||||||
const Nd4jLong iStride0 = input.stridesOf()[0];
|
const Nd4jLong iStride0 = input.stridesOf()[0];
|
||||||
|
@ -2099,94 +1397,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
||||||
const int oH = gradO.sizeAt(3);
|
const int oH = gradO.sizeAt(3);
|
||||||
const int oW = gradO.sizeAt(4);
|
const int oW = gradO.sizeAt(4);
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("pooling3d_bp"));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (streams[0].checkAndReset({&input, &gradO}, {&gradI}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0})) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
|
||||||
mkldnn::algorithm algorithm;
|
|
||||||
|
|
||||||
ConvolutionUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true,
|
|
||||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, &input, &gradI, &gradO, algorithm,
|
|
||||||
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
|
||||||
|
|
||||||
// input is sometimes null, so we can't rely on pool_src_md being valid
|
|
||||||
if (const_cast<NDArray&>(input).buffer() == nullptr) {
|
|
||||||
pool_src_md = pool_diff_src_md;
|
|
||||||
user_src_md = user_diff_src_md;
|
|
||||||
}
|
|
||||||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md,
|
|
||||||
pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
|
||||||
|
|
||||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
|
|
||||||
|
|
||||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
|
||||||
auto userB_src_memory = mkldnn::memory({user_diff_src_md, engine}, gradI.buffer());
|
|
||||||
auto userB_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray&>(gradO).buffer());
|
|
||||||
|
|
||||||
auto poolB_src_memory = userB_src_memory;
|
|
||||||
streams[0].addMemory(userB_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
|
|
||||||
!= userB_src_memory.get_primitive_desc()) {
|
|
||||||
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_primitive_desc());
|
|
||||||
streams[0].addMemory(poolB_src_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto poolB_dst_memory = userB_dst_memory;
|
|
||||||
streams[0].addMemory(userB_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_dst_primitive_desc())
|
|
||||||
!= userB_dst_memory.get_primitive_desc()) {
|
|
||||||
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_primitive_desc());
|
|
||||||
streams[0].addMemory(poolB_dst_memory);
|
|
||||||
streams[0].addOperation(reorder(userB_dst_memory, poolB_dst_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (algorithm == mkldnn::pooling_max) {
|
|
||||||
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
|
|
||||||
|
|
||||||
auto pool_src_memory = user_src_memory;
|
|
||||||
streams[0].addMemory(user_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
|
|
||||||
!= user_src_memory.get_primitive_desc()) {
|
|
||||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_src_memory);
|
|
||||||
streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_dst_memory);
|
|
||||||
|
|
||||||
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_primitive_desc());
|
|
||||||
streams[0].addMemory(pool_workspace_memory);
|
|
||||||
|
|
||||||
streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory, pool_workspace_memory));
|
|
||||||
streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, pool_workspace_memory, poolB_src_memory));
|
|
||||||
} else {
|
|
||||||
streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, poolB_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
|
|
||||||
!= userB_src_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(poolB_src_memory, userB_src_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0);
|
nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0);
|
||||||
|
|
||||||
const Nd4jLong iStride0 = input.stridesOf()[0];
|
const Nd4jLong iStride0 = input.stridesOf()[0];
|
||||||
|
|
|
@ -27,107 +27,9 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
using namespace mkldnn;
|
|
||||||
|
|
||||||
static void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
|
||||||
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
|
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
|
|
||||||
const Nd4jLong* shape = src->getShapeInfo();
|
|
||||||
long rank = shape[0];
|
|
||||||
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
|
||||||
long dim2 = axis >= 2 ? 1 : 2;
|
|
||||||
long dim3 = axis >= 3 ? 2 : 3;
|
|
||||||
mkldnn::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
|
||||||
|
|
||||||
auto type = mkldnn::memory::data_type::f32;
|
|
||||||
auto format = axis == 1 ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
|
|
||||||
auto supposed_to_be_any_format = format; // doesn't work with "any"
|
|
||||||
|
|
||||||
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
|
|
||||||
*lrn_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
|
|
||||||
user_src_md->data.format = mkldnn_blocked;
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[0];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[dim1];
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
|
||||||
user_src_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
|
|
||||||
*lrn_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
|
|
||||||
user_diff_src_md->data.format = mkldnn_blocked;
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[0];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[dim1];
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
|
||||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
|
|
||||||
*lrn_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
|
||||||
*user_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
|
|
||||||
user_dst_md->data.format = mkldnn_blocked;
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[0];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[dim1];
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
|
||||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int lrnFunctor_(nd4j::graph::Context& block, NDArray* input, NDArray* output, int depth, float bias, float alpha, float beta) {
|
static int lrnFunctor_(nd4j::graph::Context& block, NDArray* input, NDArray* output, int depth, float bias, float alpha, float beta) {
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
|
||||||
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output})) {
|
|
||||||
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
|
|
||||||
if (streams.empty()) {
|
|
||||||
streams.push_back(MKLDNNStream("lrn"));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (streams[0].checkAndReset({input}, {output}, {(float)bias, (float)alpha, (float)beta}, {depth})) {
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);
|
|
||||||
|
|
||||||
getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md, &user_src_md, nullptr, &user_dst_md, input->rankOf() - 1);
|
|
||||||
|
|
||||||
auto lrn_desc = lrn_forward::desc(prop_kind::forward_inference, lrn_across_channels, lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias);
|
|
||||||
|
|
||||||
auto engine = streams[0].getEngine();
|
|
||||||
auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine);
|
|
||||||
auto user_src_memory = mkldnn::memory({user_src_md, engine}, input->buffer());
|
|
||||||
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output->buffer());
|
|
||||||
|
|
||||||
auto lrn_src_memory = user_src_memory;
|
|
||||||
streams[0].addMemory(user_src_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(lrn_prim_desc.src_primitive_desc())
|
|
||||||
!= user_src_memory.get_primitive_desc()) {
|
|
||||||
lrn_src_memory = mkldnn::memory(lrn_prim_desc.src_primitive_desc());
|
|
||||||
streams[0].addMemory(lrn_src_memory);
|
|
||||||
streams[0].addOperation(reorder(user_src_memory, lrn_src_memory));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto lrn_dst_memory = user_dst_memory;
|
|
||||||
streams[0].addMemory(user_dst_memory);
|
|
||||||
if (mkldnn::memory::primitive_desc(lrn_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
lrn_dst_memory = mkldnn::memory(lrn_prim_desc.dst_primitive_desc());
|
|
||||||
streams[0].addMemory(lrn_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].addOperation(lrn_forward(lrn_prim_desc, lrn_src_memory, lrn_dst_memory));
|
|
||||||
|
|
||||||
if (mkldnn::memory::primitive_desc(lrn_prim_desc.dst_primitive_desc())
|
|
||||||
!= user_dst_memory.get_primitive_desc()) {
|
|
||||||
streams[0].addOperation(reorder(lrn_dst_memory, user_dst_memory));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[0].submitAndWait();
|
|
||||||
return ND4J_STATUS_OK;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
nd4j_debug("MKL-DNN is not used for lrn!\n", 0);
|
nd4j_debug("MKL-DNN is not used for lrn!\n", 0);
|
||||||
|
|
||||||
const int rank = input->rankOf();
|
const int rank = input->rankOf();
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
#include <exceptions/graph_exception.h>
|
#include <exceptions/graph_exception.h>
|
||||||
#include <exceptions/unresolved_input_exception.h>
|
#include <exceptions/unresolved_input_exception.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -501,7 +502,22 @@ namespace nd4j {
|
||||||
prepTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeStart - timeEnter).count();
|
prepTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeStart - timeEnter).count();
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jStatus status = this->validateAndExecute(*block);
|
|
||||||
|
Nd4jStatus status;
|
||||||
|
bool hasHelper = false;
|
||||||
|
|
||||||
|
// if we have platform-specific helper for this op - invoke it
|
||||||
|
if (OpRegistrator::getInstance()->hasHelper(this->getOpHash())) {
|
||||||
|
auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash());
|
||||||
|
if (helper->isUsable(*block)) {
|
||||||
|
status = helper->invokeHelper(*block);
|
||||||
|
hasHelper = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we don't have platform-specific helper - invoke generic implementation
|
||||||
|
if (!hasHelper)
|
||||||
|
status = this->validateAndExecute(*block);
|
||||||
|
|
||||||
// optionally saving execution time
|
// optionally saving execution time
|
||||||
if (Environment::getInstance()->isProfiling()) {
|
if (Environment::getInstance()->isProfiling()) {
|
||||||
|
|
|
@ -113,8 +113,13 @@ namespace nd4j {
|
||||||
for (auto x : _uniqueD)
|
for (auto x : _uniqueD)
|
||||||
delete x;
|
delete x;
|
||||||
|
|
||||||
|
for (auto x: _uniqueH)
|
||||||
|
delete x;
|
||||||
|
|
||||||
_uniqueD.clear();
|
_uniqueD.clear();
|
||||||
|
|
||||||
|
_uniqueH.clear();
|
||||||
|
|
||||||
_declarablesD.clear();
|
_declarablesD.clear();
|
||||||
|
|
||||||
_declarablesLD.clear();
|
_declarablesLD.clear();
|
||||||
|
@ -144,6 +149,8 @@ namespace nd4j {
|
||||||
|
|
||||||
return _opsList.c_str();
|
return _opsList.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bool OpRegistrator::registerOperation(const char* name, nd4j::ops::DeclarableOp* op) {
|
bool OpRegistrator::registerOperation(const char* name, nd4j::ops::DeclarableOp* op) {
|
||||||
std::string str(name);
|
std::string str(name);
|
||||||
std::pair<std::string, nd4j::ops::DeclarableOp*> pair(str, op);
|
std::pair<std::string, nd4j::ops::DeclarableOp*> pair(str, op);
|
||||||
|
@ -165,6 +172,19 @@ namespace nd4j {
|
||||||
return registerOperation(op->getOpName()->c_str(), op);
|
return registerOperation(op->getOpName()->c_str(), op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OpRegistrator::registerHelper(nd4j::ops::platforms::PlatformHelper* op) {
|
||||||
|
if (_helpersLH.count(op->hash()) > 0)
|
||||||
|
throw std::runtime_error("Tried to double register PlatformHelper");
|
||||||
|
|
||||||
|
_uniqueH.emplace_back(op);
|
||||||
|
|
||||||
|
std::pair<std::string, nd4j::ops::platforms::PlatformHelper*> pair(op->name(), op);
|
||||||
|
_helpersH.insert(pair);
|
||||||
|
|
||||||
|
std::pair<Nd4jLong, nd4j::ops::platforms::PlatformHelper*> pair2(op->hash(), op);
|
||||||
|
_helpersLH.insert(pair2);
|
||||||
|
}
|
||||||
|
|
||||||
nd4j::ops::DeclarableOp* OpRegistrator::getOperation(const char *name) {
|
nd4j::ops::DeclarableOp* OpRegistrator::getOperation(const char *name) {
|
||||||
std::string str(name);
|
std::string str(name);
|
||||||
return getOperation(str);
|
return getOperation(str);
|
||||||
|
@ -207,6 +227,16 @@ namespace nd4j {
|
||||||
return _declarablesD.at(name);
|
return _declarablesD.at(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nd4j::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper(Nd4jLong hash) {
|
||||||
|
if (_helpersLH.count(hash) == 0)
|
||||||
|
throw std::runtime_error("Requested helper can't be found");
|
||||||
|
|
||||||
|
return _helpersLH[hash];
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OpRegistrator::hasHelper(Nd4jLong hash) {
|
||||||
|
return _helpersLH.count(hash) > 0;
|
||||||
|
}
|
||||||
|
|
||||||
int OpRegistrator::numberOfOperations() {
|
int OpRegistrator::numberOfOperations() {
|
||||||
return (int) _declarablesLD.size();
|
return (int) _declarablesLD.size();
|
||||||
|
|
|
@ -0,0 +1,86 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../PlatformHelper.h"
|
||||||
|
#include <graph/Variable.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PlatformHelper::PlatformHelper(const char *name) {
|
||||||
|
// we just store name/hash of target operation
|
||||||
|
_name = std::string(name);
|
||||||
|
_hash = HashHelper::getInstance()->getLongHash(_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
nd4j::NDArray *PlatformHelper::getZ(graph::Context &ctx, int inputId) {
|
||||||
|
NDArray *z = nullptr;
|
||||||
|
|
||||||
|
if (ctx.isFastPath()) {
|
||||||
|
if (ctx.fastpath_out().size() <= inputId) {
|
||||||
|
if (ctx.isInplace()) {
|
||||||
|
z = ctx.fastpath_in()[inputId];
|
||||||
|
} else
|
||||||
|
throw std::runtime_error("fastpath_out: unresolved output array");
|
||||||
|
} else {
|
||||||
|
z = ctx.fastpath_out()[inputId];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::pair<int, int> pair(ctx.nodeId(), inputId);
|
||||||
|
|
||||||
|
if (ctx.isInplace()) {
|
||||||
|
z = ctx.variable(inputId)->getNDArray();
|
||||||
|
|
||||||
|
// hypothetically it's possible to have no variable. chances are low, but who knows. let's just create it for now
|
||||||
|
if (!ctx.getVariableSpace()->hasVariable(pair)) {
|
||||||
|
auto var = new graph::Variable();
|
||||||
|
ctx.getVariableSpace()->putVariable(pair, var);
|
||||||
|
}
|
||||||
|
|
||||||
|
// now we're saving input array as output array
|
||||||
|
auto var = ctx.getVariableSpace()->getVariable(pair);
|
||||||
|
var->markRemovable(false);
|
||||||
|
var->setNDArray(z);
|
||||||
|
} else if (!ctx.isInplace()) {
|
||||||
|
auto var = ctx.variable(pair);
|
||||||
|
if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) {
|
||||||
|
z = var->getNDArray();
|
||||||
|
} else {
|
||||||
|
nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nd4j_printf("BOOM!\n", "");
|
||||||
|
throw std::runtime_error("Boom!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return z;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PlatformHelper::name() {
|
||||||
|
return _name;
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong PlatformHelper::hash() {
|
||||||
|
return _hash;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1 @@
|
||||||
|
This folder contains platform-specific optimized implementations for custom operations
|
|
@ -0,0 +1,143 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(avgpool2d) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
|
||||||
|
input->rankOf());
|
||||||
|
|
||||||
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||||
|
auto argI = *(block.getIArguments());
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const auto kH = INT_ARG(0);
|
||||||
|
const auto kW = INT_ARG(1);
|
||||||
|
const auto sH = INT_ARG(2);
|
||||||
|
const auto sW = INT_ARG(3);
|
||||||
|
int pH = INT_ARG(4);
|
||||||
|
int pW = INT_ARG(5);
|
||||||
|
const auto dH = INT_ARG(6);
|
||||||
|
const auto dW = INT_ARG(7);
|
||||||
|
const auto isSameMode = static_cast<bool>(INT_ARG(8));
|
||||||
|
const auto extraParam0 = INT_ARG(9);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
|
||||||
|
dH, dW);
|
||||||
|
|
||||||
|
int oH = 0;
|
||||||
|
int oW = 0;
|
||||||
|
|
||||||
|
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||||
|
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||||
|
|
||||||
|
if (!isNCHW) {
|
||||||
|
input = new NDArray(
|
||||||
|
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
output = new NDArray(
|
||||||
|
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
|
||||||
|
if (isSameMode)
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
|
const int bS = input->sizeAt(0);
|
||||||
|
const int iC = input->sizeAt(1);
|
||||||
|
const int oC = output->sizeAt(1);
|
||||||
|
|
||||||
|
auto poolingMode = PoolingType::AVG_POOL;
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||||
|
mkldnn::algorithm algorithm;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||||
|
true,
|
||||||
|
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
|
||||||
|
algorithm,
|
||||||
|
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||||
|
&user_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||||
|
pool_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||||
|
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||||
|
auto pool_src_memory = user_src_memory;
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||||
|
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||||
|
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||||
|
}
|
||||||
|
auto pool_dst_memory = user_dst_memory;
|
||||||
|
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||||
|
}
|
||||||
|
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||||
|
{MKLDNN_ARG_DST, pool_dst_memory}});
|
||||||
|
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||||
|
}
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
//streams[0].submitAndWait();
|
||||||
|
|
||||||
|
if (!isNCHW) {
|
||||||
|
delete input;
|
||||||
|
delete output;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(avgpool2d) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,153 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(avgpool2d_bp) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
auto gradO = INPUT_VARIABLE(
|
||||||
|
1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
|
||||||
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
int extraParam0 = INT_ARG(9);
|
||||||
|
int isNCHW =
|
||||||
|
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||||
|
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||||
|
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
|
||||||
|
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
||||||
|
indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(
|
||||||
|
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||||
|
std::string expectedGradIShape = ShapeUtils::shapeAsString(
|
||||||
|
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||||
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||||
|
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||||
|
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||||
|
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||||
|
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||||
|
|
||||||
|
|
||||||
|
if (!isNCHW) {
|
||||||
|
input = new NDArray(input->permute(
|
||||||
|
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
gradI = new NDArray(gradI->permute(
|
||||||
|
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
gradO = new NDArray(gradO->permute(
|
||||||
|
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
|
auto poolingMode = PoolingType::AVG_POOL;
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||||
|
mkldnn::algorithm algorithm;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||||
|
true,
|
||||||
|
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
|
||||||
|
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||||
|
&user_diff_src_md, &user_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
|
||||||
|
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
|
||||||
|
pool_dst_md, pool_strides, pool_kernel, pool_padding,
|
||||||
|
pool_padding_r);
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||||
|
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
|
||||||
|
pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||||
|
auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer());
|
||||||
|
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
|
||||||
|
auto poolB_src_memory = userB_src_memory;
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||||
|
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||||
|
}
|
||||||
|
auto poolB_dst_memory = userB_dst_memory;
|
||||||
|
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||||
|
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||||
|
}
|
||||||
|
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
|
||||||
|
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||||
|
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||||
|
}
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
if (!isNCHW) {
|
||||||
|
delete input;
|
||||||
|
delete gradI;
|
||||||
|
delete gradO;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(avgpool2d_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,145 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(avgpool3dnew) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto output = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
||||||
|
|
||||||
|
int kD = INT_ARG(0); // filter(kernel) depth
|
||||||
|
int kH = INT_ARG(1); // filter(kernel) height
|
||||||
|
int kW = INT_ARG(2); // filter(kernel) width
|
||||||
|
int sD = INT_ARG(3); // strides depth
|
||||||
|
int sH = INT_ARG(4); // strides height
|
||||||
|
int sW = INT_ARG(5); // strides width
|
||||||
|
int pD = INT_ARG(6); // paddings depth
|
||||||
|
int pH = INT_ARG(7); // paddings height
|
||||||
|
int pW = INT_ARG(8); // paddings width
|
||||||
|
int dD = INT_ARG(9); // dilations depth
|
||||||
|
int dH = INT_ARG(10); // dilations height
|
||||||
|
int dW = INT_ARG(11); // dilations width
|
||||||
|
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||||
|
int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
|
||||||
|
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||||
|
"MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
|
||||||
|
input->rankOf());
|
||||||
|
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
|
||||||
|
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||||
|
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||||
|
|
||||||
|
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||||
|
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||||
|
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
|
||||||
|
"MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
|
||||||
|
expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||||
|
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
|
||||||
|
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
|
||||||
|
|
||||||
|
if (!isNCDHW) {
|
||||||
|
input = new NDArray(
|
||||||
|
input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
output = new NDArray(
|
||||||
|
output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
|
|
||||||
|
auto poolingMode = PoolingType::AVG_POOL;
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||||
|
mkldnn::algorithm algorithm;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||||
|
extraParam0, true,
|
||||||
|
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
|
||||||
|
algorithm,
|
||||||
|
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||||
|
&user_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||||
|
pool_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||||
|
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||||
|
auto pool_src_memory = user_src_memory;
|
||||||
|
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||||
|
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||||
|
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||||
|
}
|
||||||
|
auto pool_dst_memory = user_dst_memory;
|
||||||
|
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||||
|
}
|
||||||
|
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||||
|
{MKLDNN_ARG_DST, pool_dst_memory}});
|
||||||
|
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||||
|
}
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
if (!isNCDHW) {
|
||||||
|
delete input;
|
||||||
|
delete output;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(avgpool3dnew) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,158 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(avgpool3dnew_bp) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto gradO = INPUT_VARIABLE(
|
||||||
|
1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||||
|
|
||||||
|
const int kD = INT_ARG(0); // filter(kernel) depth
|
||||||
|
const int kH = INT_ARG(1); // filter(kernel) height
|
||||||
|
const int kW = INT_ARG(2); // filter(kernel) width
|
||||||
|
const int sD = INT_ARG(3); // strides depth
|
||||||
|
const int sH = INT_ARG(4); // strides height
|
||||||
|
const int sW = INT_ARG(5); // strides width
|
||||||
|
int pD = INT_ARG(6); // paddings depth
|
||||||
|
int pH = INT_ARG(7); // paddings height
|
||||||
|
int pW = INT_ARG(8); // paddings width
|
||||||
|
const int dD = INT_ARG(9); // dilations depth
|
||||||
|
const int dH = INT_ARG(10); // dilations height
|
||||||
|
const int dW = INT_ARG(11); // dilations width
|
||||||
|
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||||
|
int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
|
||||||
|
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||||
|
"MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
|
||||||
|
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
|
||||||
|
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||||
|
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||||
|
|
||||||
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||||
|
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||||
|
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||||
|
{bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||||
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||||
|
"MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||||
|
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||||
|
"MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||||
|
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||||
|
|
||||||
|
if (!isNCDHW) {
|
||||||
|
input = new NDArray(input->permute(
|
||||||
|
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
gradI = new NDArray(gradI->permute(
|
||||||
|
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
gradO = new NDArray(gradO->permute(
|
||||||
|
{0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
auto poolingMode = PoolingType::AVG_POOL;
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||||
|
mkldnn::algorithm algorithm;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||||
|
extraParam0, true,
|
||||||
|
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
|
||||||
|
algorithm,
|
||||||
|
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||||
|
&user_diff_src_md, &user_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
if (input->buffer() == nullptr) {
|
||||||
|
pool_src_md = pool_diff_src_md;
|
||||||
|
user_src_md = user_diff_src_md;
|
||||||
|
}
|
||||||
|
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||||
|
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
|
||||||
|
pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||||
|
auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||||
|
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
|
||||||
|
auto poolB_src_memory = userB_src_memory;
|
||||||
|
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||||
|
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||||
|
}
|
||||||
|
auto poolB_dst_memory = userB_dst_memory;
|
||||||
|
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||||
|
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||||
|
}
|
||||||
|
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
|
||||||
|
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||||
|
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||||
|
}
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
if (!isNCDHW) {
|
||||||
|
delete input;
|
||||||
|
delete gradI;
|
||||||
|
delete gradO;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(avgpool3dnew_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,166 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(batchnorm_new) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto mean = INPUT_VARIABLE(1);
|
||||||
|
auto variance = INPUT_VARIABLE(2);
|
||||||
|
NDArray *gamma = nullptr;
|
||||||
|
NDArray *beta = nullptr;
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const bool applyScale = (bool) INT_ARG(0);
|
||||||
|
const bool applyOffset = (bool) INT_ARG(1);
|
||||||
|
const double epsilon = T_ARG(0);
|
||||||
|
|
||||||
|
if (applyScale)
|
||||||
|
gamma = INPUT_VARIABLE(3);
|
||||||
|
if (applyOffset)
|
||||||
|
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
||||||
|
|
||||||
|
std::vector<int> axes;
|
||||||
|
if (block.numI() > 2)
|
||||||
|
for (int i = 2; i < block.numI(); ++i)
|
||||||
|
axes.push_back(INT_ARG(i));
|
||||||
|
else
|
||||||
|
axes.push_back(input->rankOf() - 1);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
|
||||||
|
NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
|
||||||
|
weights({0, 1, 0, 0}).assign(1.0f);
|
||||||
|
weights({1, 2, 0, 0}).assign(0.0f);
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(
|
||||||
|
empty), user_dst_md(empty);
|
||||||
|
|
||||||
|
auto norm_flag = normalization_flags::use_global_stats;
|
||||||
|
if (applyScale || applyOffset)
|
||||||
|
norm_flag |= normalization_flags::use_scale_shift;
|
||||||
|
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
|
||||||
|
&batchnorm_src_md, nullptr, &batchnorm_dst_md,
|
||||||
|
&user_src_md, nullptr, &user_dst_md, axes[0]);
|
||||||
|
|
||||||
|
auto batchnorm_desc = batch_normalization_forward::desc(prop_kind::forward_inference, batchnorm_src_md, epsilon, norm_flag);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto batchnorm_prim_desc = batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||||
|
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||||
|
auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine,
|
||||||
|
mean->buffer());
|
||||||
|
auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine,
|
||||||
|
variance->buffer());
|
||||||
|
auto batchnorm_src_memory = user_src_memory;
|
||||||
|
mkldnn::memory m(batchnorm_src_md, engine);
|
||||||
|
if (m.get_desc() != user_src_memory.get_desc()) {
|
||||||
|
batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine);
|
||||||
|
reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
|
||||||
|
batchnorm_src_memory);
|
||||||
|
}
|
||||||
|
auto batchnorm_dst_memory = user_dst_memory;
|
||||||
|
if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine);
|
||||||
|
}
|
||||||
|
if (applyScale || applyOffset) {
|
||||||
|
if (gamma != nullptr) {
|
||||||
|
weights({0, 1, 0, 0}).assign(gamma);
|
||||||
|
}
|
||||||
|
if (beta != nullptr) {
|
||||||
|
weights({1, 2, 0, 0}).assign(beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
|
||||||
|
batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, batchnorm_src_memory},
|
||||||
|
{MKLDNN_ARG_MEAN, batchnorm_mean_memory},
|
||||||
|
{MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
|
||||||
|
{MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory},
|
||||||
|
{MKLDNN_ARG_DST, batchnorm_dst_memory}});
|
||||||
|
} else {
|
||||||
|
batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, batchnorm_src_memory},
|
||||||
|
{MKLDNN_ARG_MEAN, batchnorm_mean_memory},
|
||||||
|
{MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
|
||||||
|
{MKLDNN_ARG_DST, batchnorm_dst_memory}});
|
||||||
|
}
|
||||||
|
if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
|
||||||
|
user_dst_memory);
|
||||||
|
}
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(batchnorm_new) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto mean = INPUT_VARIABLE(1);
|
||||||
|
auto variance = INPUT_VARIABLE(2);
|
||||||
|
NDArray *gamma = nullptr;
|
||||||
|
NDArray *beta = nullptr;
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const bool applyScale = (bool) INT_ARG(0);
|
||||||
|
const bool applyOffset = (bool) INT_ARG(1);
|
||||||
|
const double epsilon = T_ARG(0);
|
||||||
|
|
||||||
|
if (applyScale)
|
||||||
|
gamma = INPUT_VARIABLE(3);
|
||||||
|
if (applyOffset)
|
||||||
|
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
||||||
|
|
||||||
|
std::vector<int> axes;
|
||||||
|
if (block.numI() > 2)
|
||||||
|
for (int i = 2; i < block.numI(); ++i)
|
||||||
|
axes.push_back(INT_ARG(i));
|
||||||
|
else
|
||||||
|
axes.push_back(input->rankOf() - 1);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() &&
|
||||||
|
nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) &&
|
||||||
|
axes.size() == 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,153 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
|
||||||
|
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
|
||||||
|
const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
|
||||||
|
const int isNCHW) {
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
|
||||||
|
indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
|
if(isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
||||||
|
empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||||
|
empty);
|
||||||
|
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
||||||
|
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
|
||||||
|
bias, output,
|
||||||
|
&conv_src_md, nullptr, &conv_weights_md, nullptr,
|
||||||
|
&conv_bias_md, &conv_dst_md,
|
||||||
|
&user_src_md, nullptr, &user_weights_md, nullptr,
|
||||||
|
&user_bias_md, &user_dst_md,
|
||||||
|
conv_strides, conv_padding, conv_padding_r);
|
||||||
|
auto conv_desc = bias != nullptr
|
||||||
|
? convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding,
|
||||||
|
conv_padding_r)
|
||||||
|
: convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding,
|
||||||
|
conv_padding_r);
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
|
||||||
|
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||||
|
const_cast<NDArray *>(weights)->buffer());
|
||||||
|
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||||
|
auto conv_src_memory = user_src_memory;
|
||||||
|
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||||
|
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
|
||||||
|
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
|
||||||
|
}
|
||||||
|
auto conv_weights_memory = user_weights_memory;
|
||||||
|
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
|
||||||
|
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
|
||||||
|
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
|
||||||
|
conv_weights_memory);
|
||||||
|
}
|
||||||
|
auto conv_dst_memory = user_dst_memory;
|
||||||
|
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
|
||||||
|
}
|
||||||
|
if (bias != nullptr) {
|
||||||
|
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine,
|
||||||
|
const_cast<NDArray *>(bias)->buffer());
|
||||||
|
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
|
||||||
|
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
|
||||||
|
{MKLDNN_ARG_BIAS, conv_bias_memory},
|
||||||
|
{MKLDNN_ARG_DST, conv_dst_memory}});
|
||||||
|
} else {
|
||||||
|
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
|
||||||
|
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
|
||||||
|
{MKLDNN_ARG_DST, conv_dst_memory}});
|
||||||
|
}
|
||||||
|
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
|
||||||
|
}
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_IMPL(conv2d) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
||||||
|
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
|
||||||
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
|
||||||
|
|
||||||
|
conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(conv2d) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto weights = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
// conv2d is only available for float32 dtype
|
||||||
|
return block.isUseMKLDNN() && input->dataType() == nd4j::DataType::FLOAT32 &&
|
||||||
|
weights->dataType() == nd4j::DataType::FLOAT32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,243 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(conv2d_bp) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
auto weights = INPUT_VARIABLE(
|
||||||
|
1); // [kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||||
|
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(
|
||||||
|
1); // [kH, kW, iC, oC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||||
|
"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",
|
||||||
|
input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 4, 0,
|
||||||
|
"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",
|
||||||
|
weights->rankOf());
|
||||||
|
REQUIRE_TRUE(gradO->rankOf() == 4, 0,
|
||||||
|
"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",
|
||||||
|
gradO->rankOf());
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
||||||
|
indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
|
if (isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||||
|
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||||
|
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
||||||
|
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
|
||||||
|
gradB, gradO,
|
||||||
|
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
|
||||||
|
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
||||||
|
&user_src_md, &user_diff_src_md, &user_weights_md,
|
||||||
|
&user_diff_weights_md, &user_bias_md, &user_dst_md,
|
||||||
|
conv_strides, conv_padding, conv_padding_r);
|
||||||
|
auto conv_desc = gradB != nullptr
|
||||||
|
? convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding,
|
||||||
|
conv_padding_r)
|
||||||
|
: convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding,
|
||||||
|
conv_padding_r);
|
||||||
|
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
|
||||||
|
LaunchContext::defaultContext()->engine()));
|
||||||
|
if (gradW != nullptr) {
|
||||||
|
auto convW_desc = gradB != nullptr
|
||||||
|
? convolution_backward_weights::desc(
|
||||||
|
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding, conv_padding_r)
|
||||||
|
: convolution_backward_weights::desc(
|
||||||
|
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding, conv_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
||||||
|
conv_prim_desc);
|
||||||
|
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
|
||||||
|
const_cast<NDArray *>(input)->buffer());
|
||||||
|
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||||
|
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||||
|
const_cast<NDArray *>(gradO)->buffer());
|
||||||
|
|
||||||
|
auto convW_src_memory = userW_src_memory;
|
||||||
|
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
||||||
|
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
|
||||||
|
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
||||||
|
convW_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convW_weights_memory = userW_weights_memory;
|
||||||
|
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||||
|
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convW_dst_memory = userW_dst_memory;
|
||||||
|
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||||
|
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
||||||
|
convW_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gradB != nullptr) {
|
||||||
|
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||||
|
gradB->buffer());
|
||||||
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||||
|
} else {
|
||||||
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||||
|
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
||||||
|
userW_weights_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gradI != nullptr) {
|
||||||
|
auto convI_desc =
|
||||||
|
convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md,
|
||||||
|
conv_weights_md, conv_dst_md, conv_strides,
|
||||||
|
conv_padding, conv_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
||||||
|
conv_prim_desc);
|
||||||
|
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||||
|
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||||
|
const_cast<NDArray *>(weights)->buffer());
|
||||||
|
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||||
|
const_cast<NDArray *>(gradO)->buffer());
|
||||||
|
|
||||||
|
auto convI_src_memory = userI_src_memory;
|
||||||
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||||
|
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convI_weights_memory = userI_weights_memory;
|
||||||
|
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||||
|
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
|
||||||
|
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
||||||
|
convI_weights_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convI_dst_memory = userI_dst_memory;
|
||||||
|
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||||
|
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
||||||
|
convI_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
|
||||||
|
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
|
||||||
|
|
||||||
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||||
|
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
||||||
|
userI_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
};
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(conv2d_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
auto weights = INPUT_VARIABLE(
|
||||||
|
1); // [kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||||
|
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(
|
||||||
|
1); // [kH, kW, iC, oC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() &&
|
||||||
|
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,167 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(conv3dnew) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto output = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||||
|
"CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !",
|
||||||
|
input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 5, 0,
|
||||||
|
"CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !",
|
||||||
|
weights->rankOf());
|
||||||
|
|
||||||
|
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
||||||
|
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
||||||
|
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
||||||
|
int sD = INT_ARG(3); // strides depth
|
||||||
|
int sH = INT_ARG(4); // strides height
|
||||||
|
int sW = INT_ARG(5); // strides width
|
||||||
|
int pD = INT_ARG(6); // paddings depth
|
||||||
|
int pH = INT_ARG(7); // paddings height
|
||||||
|
int pW = INT_ARG(8); // paddings width
|
||||||
|
int dD = INT_ARG(9); // dilations depth
|
||||||
|
int dH = INT_ARG(10); // dilations height
|
||||||
|
int dW = INT_ARG(11); // dilations width
|
||||||
|
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||||
|
int isNCDHW =
|
||||||
|
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||||
|
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||||
|
|
||||||
|
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||||
|
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
|
||||||
|
"CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !",
|
||||||
|
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
if (bias)
|
||||||
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
|
||||||
|
"CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
|
||||||
|
oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
if (isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
||||||
|
empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||||
|
empty);
|
||||||
|
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
||||||
|
isNCDHW,
|
||||||
|
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
|
||||||
|
nullptr, bias, output,
|
||||||
|
&conv_src_md, nullptr, &conv_weights_md, nullptr,
|
||||||
|
&conv_bias_md, &conv_dst_md,
|
||||||
|
&user_src_md, nullptr, &user_weights_md, nullptr,
|
||||||
|
&user_bias_md, &user_dst_md,
|
||||||
|
conv_strides, conv_padding, conv_padding_r);
|
||||||
|
auto conv_desc = bias != nullptr
|
||||||
|
? convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding,
|
||||||
|
conv_padding_r)
|
||||||
|
: convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding,
|
||||||
|
conv_padding_r);
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
|
||||||
|
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||||
|
const_cast<NDArray *>(weights)->buffer());
|
||||||
|
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||||
|
auto conv_src_memory = user_src_memory;
|
||||||
|
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||||
|
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
|
||||||
|
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
|
||||||
|
}
|
||||||
|
auto conv_weights_memory = user_weights_memory;
|
||||||
|
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
|
||||||
|
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
|
||||||
|
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
|
||||||
|
conv_weights_memory);
|
||||||
|
}
|
||||||
|
auto conv_dst_memory = user_dst_memory;
|
||||||
|
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
|
||||||
|
}
|
||||||
|
if (bias != nullptr) {
|
||||||
|
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
|
||||||
|
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
|
||||||
|
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
|
||||||
|
{MKLDNN_ARG_BIAS, conv_bias_memory},
|
||||||
|
{MKLDNN_ARG_DST, conv_dst_memory}});
|
||||||
|
} else {
|
||||||
|
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
|
||||||
|
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
|
||||||
|
{MKLDNN_ARG_DST, conv_dst_memory}});
|
||||||
|
}
|
||||||
|
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
|
||||||
|
}
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(conv3dnew) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto output = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,263 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(conv3dnew_bp) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(
|
||||||
|
1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||||
|
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(
|
||||||
|
1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !",
|
||||||
|
input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 5, 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !",
|
||||||
|
weights->rankOf());
|
||||||
|
REQUIRE_TRUE(gradO->rankOf() == 5, 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !",
|
||||||
|
gradO->rankOf());
|
||||||
|
|
||||||
|
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
||||||
|
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
||||||
|
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
||||||
|
int sD = INT_ARG(3); // strides depth
|
||||||
|
int sH = INT_ARG(4); // strides height
|
||||||
|
int sW = INT_ARG(5); // strides width
|
||||||
|
int pD = INT_ARG(6); // paddings depth
|
||||||
|
int pH = INT_ARG(7); // paddings height
|
||||||
|
int pW = INT_ARG(8); // paddings width
|
||||||
|
int dD = INT_ARG(9); // dilations depth
|
||||||
|
int dH = INT_ARG(10); // dilations height
|
||||||
|
int dW = INT_ARG(11); // dilations width
|
||||||
|
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||||
|
int isNDHWC =
|
||||||
|
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||||
|
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||||
|
|
||||||
|
int trueoD, trueoH, trueoW; // true output depth/height/width
|
||||||
|
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH,
|
||||||
|
dW, iD, iH, iW, isSameMode);
|
||||||
|
|
||||||
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||||
|
{bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||||
|
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||||
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !",
|
||||||
|
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !",
|
||||||
|
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
if (bias)
|
||||||
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
|
||||||
|
oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||||
|
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||||
|
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
||||||
|
isNDHWC,
|
||||||
|
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
|
||||||
|
gradW, gradB, gradO,
|
||||||
|
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
|
||||||
|
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
||||||
|
&user_src_md, &user_diff_src_md, &user_weights_md,
|
||||||
|
&user_diff_weights_md, &user_bias_md, &user_dst_md,
|
||||||
|
conv_strides, conv_padding, conv_padding_r);
|
||||||
|
auto conv_desc = gradB != nullptr
|
||||||
|
? convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding,
|
||||||
|
conv_padding_r)
|
||||||
|
: convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding,
|
||||||
|
conv_padding_r);
|
||||||
|
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
|
||||||
|
LaunchContext::defaultContext()->engine()));
|
||||||
|
if (gradW != nullptr) {
|
||||||
|
auto convW_desc = gradB != nullptr
|
||||||
|
? convolution_backward_weights::desc(
|
||||||
|
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding, conv_padding_r)
|
||||||
|
: convolution_backward_weights::desc(
|
||||||
|
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding, conv_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
||||||
|
conv_prim_desc);
|
||||||
|
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
|
||||||
|
const_cast<NDArray *>(input)->buffer());
|
||||||
|
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||||
|
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||||
|
const_cast<NDArray *>(gradO)->buffer());
|
||||||
|
|
||||||
|
auto convW_src_memory = userW_src_memory;
|
||||||
|
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
||||||
|
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
|
||||||
|
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
||||||
|
convW_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convW_weights_memory = userW_weights_memory;
|
||||||
|
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||||
|
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convW_dst_memory = userW_dst_memory;
|
||||||
|
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||||
|
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
||||||
|
convW_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gradB != nullptr) {
|
||||||
|
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||||
|
gradB->buffer());
|
||||||
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||||
|
} else {
|
||||||
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||||
|
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
||||||
|
userW_weights_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
if (gradI != nullptr) {
|
||||||
|
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto,
|
||||||
|
conv_diff_src_md, conv_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_padding,
|
||||||
|
conv_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
||||||
|
conv_prim_desc);
|
||||||
|
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||||
|
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||||
|
const_cast<NDArray *>(weights)->buffer());
|
||||||
|
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||||
|
const_cast<NDArray *>(gradO)->buffer());
|
||||||
|
|
||||||
|
auto convI_src_memory = userI_src_memory;
|
||||||
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||||
|
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convI_weights_memory = userI_weights_memory;
|
||||||
|
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||||
|
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
|
||||||
|
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
||||||
|
convI_weights_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convI_dst_memory = userI_dst_memory;
|
||||||
|
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||||
|
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
||||||
|
convI_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
|
||||||
|
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
|
||||||
|
|
||||||
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||||
|
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
||||||
|
userI_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(conv3dnew_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(
|
||||||
|
1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||||
|
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(
|
||||||
|
1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() &&
|
||||||
|
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,97 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(lrn) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead",
|
||||||
|
input->rankOf());
|
||||||
|
|
||||||
|
double alpha = T_ARG(1);
|
||||||
|
double beta = T_ARG(2);
|
||||||
|
double bias = T_ARG(0);
|
||||||
|
int depth = INT_ARG(0);
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);
|
||||||
|
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md,
|
||||||
|
&user_src_md, nullptr, &user_dst_md, input->rankOf() - 1);
|
||||||
|
|
||||||
|
auto lrn_desc = lrn_forward::desc(prop_kind::forward_inference, algorithm::lrn_across_channels,
|
||||||
|
lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine);
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||||
|
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||||
|
|
||||||
|
auto lrn_src_memory = user_src_memory;
|
||||||
|
if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||||
|
lrn_src_memory = mkldnn::memory(lrn_prim_desc.src_desc(), engine);
|
||||||
|
reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto lrn_dst_memory = user_dst_memory;
|
||||||
|
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
lrn_dst_memory = mkldnn::memory(lrn_prim_desc.dst_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
lrn_forward(lrn_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, lrn_src_memory},
|
||||||
|
{MKLDNN_ARG_DST, lrn_dst_memory}});
|
||||||
|
|
||||||
|
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
|
||||||
|
PLATFORM_CHECK(lrn) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,149 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(maxpool2d) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
|
||||||
|
input->rankOf());
|
||||||
|
|
||||||
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||||
|
auto argI = *(block.getIArguments());
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const auto kH = INT_ARG(0);
|
||||||
|
const auto kW = INT_ARG(1);
|
||||||
|
const auto sH = INT_ARG(2);
|
||||||
|
const auto sW = INT_ARG(3);
|
||||||
|
int pH = INT_ARG(4);
|
||||||
|
int pW = INT_ARG(5);
|
||||||
|
const auto dH = INT_ARG(6);
|
||||||
|
const auto dW = INT_ARG(7);
|
||||||
|
const auto isSameMode = static_cast<bool>(INT_ARG(8));
|
||||||
|
|
||||||
|
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
|
||||||
|
dH, dW);
|
||||||
|
|
||||||
|
int oH = 0;
|
||||||
|
int oW = 0;
|
||||||
|
|
||||||
|
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||||
|
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||||
|
|
||||||
|
if (!isNCHW) {
|
||||||
|
input = new NDArray(
|
||||||
|
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
output = new NDArray(
|
||||||
|
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
|
||||||
|
if (isSameMode)
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
|
const int bS = input->sizeAt(0);
|
||||||
|
const int iC = input->sizeAt(1);
|
||||||
|
const int oC = output->sizeAt(1);
|
||||||
|
|
||||||
|
auto poolingMode = PoolingType::MAX_POOL;
|
||||||
|
int extraParam0 = 1;
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||||
|
mkldnn::algorithm algorithm;
|
||||||
|
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||||
|
true,
|
||||||
|
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
|
||||||
|
algorithm,
|
||||||
|
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||||
|
&user_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
|
||||||
|
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||||
|
pool_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||||
|
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||||
|
|
||||||
|
auto pool_src_memory = user_src_memory;
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||||
|
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||||
|
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto pool_dst_memory = user_dst_memory;
|
||||||
|
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||||
|
{MKLDNN_ARG_DST, pool_dst_memory}});
|
||||||
|
|
||||||
|
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
if (!isNCHW) {
|
||||||
|
delete input;
|
||||||
|
delete output;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(maxpool2d) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,178 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(maxpool2d_bp) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
auto gradO = INPUT_VARIABLE(
|
||||||
|
1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
|
||||||
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
int extraParam0 = INT_ARG(9);
|
||||||
|
int isNCHW =
|
||||||
|
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||||
|
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||||
|
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
|
||||||
|
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
||||||
|
indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(
|
||||||
|
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||||
|
std::string expectedGradIShape = ShapeUtils::shapeAsString(
|
||||||
|
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||||
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||||
|
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||||
|
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||||
|
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||||
|
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||||
|
|
||||||
|
|
||||||
|
if (!isNCHW) {
|
||||||
|
input = new NDArray(input->permute(
|
||||||
|
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
gradI = new NDArray(gradI->permute(
|
||||||
|
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
gradO = new NDArray(gradO->permute(
|
||||||
|
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
|
auto poolingMode = PoolingType::MAX_POOL;
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||||
|
mkldnn::algorithm algorithm;
|
||||||
|
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||||
|
true,
|
||||||
|
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
|
||||||
|
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||||
|
&user_diff_src_md, &user_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
|
||||||
|
// input is sometimes null, so we can't rely on pool_src_md being valid
|
||||||
|
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
|
||||||
|
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
|
||||||
|
pool_dst_md, pool_strides, pool_kernel, pool_padding,
|
||||||
|
pool_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||||
|
|
||||||
|
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
|
||||||
|
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||||
|
auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer());
|
||||||
|
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
|
||||||
|
|
||||||
|
auto poolB_src_memory = userB_src_memory;
|
||||||
|
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||||
|
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto poolB_dst_memory = userB_dst_memory;
|
||||||
|
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||||
|
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||||
|
auto pool_src_memory = user_src_memory;
|
||||||
|
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||||
|
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||||
|
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||||
|
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine);
|
||||||
|
|
||||||
|
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||||
|
{MKLDNN_ARG_DST, pool_dst_memory},
|
||||||
|
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}});
|
||||||
|
// probably wrong, fix that
|
||||||
|
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
|
||||||
|
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
|
||||||
|
|
||||||
|
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||||
|
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
if (!isNCHW) {
|
||||||
|
delete input;
|
||||||
|
delete gradI;
|
||||||
|
delete gradO;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(maxpool2d_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,155 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(maxpool3dnew) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto output = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
||||||
|
|
||||||
|
int kD = INT_ARG(0); // filter(kernel) depth
|
||||||
|
int kH = INT_ARG(1); // filter(kernel) height
|
||||||
|
int kW = INT_ARG(2); // filter(kernel) width
|
||||||
|
int sD = INT_ARG(3); // strides depth
|
||||||
|
int sH = INT_ARG(4); // strides height
|
||||||
|
int sW = INT_ARG(5); // strides width
|
||||||
|
int pD = INT_ARG(6); // paddings depth
|
||||||
|
int pH = INT_ARG(7); // paddings height
|
||||||
|
int pW = INT_ARG(8); // paddings width
|
||||||
|
int dD = INT_ARG(9); // dilations depth
|
||||||
|
int dH = INT_ARG(10); // dilations height
|
||||||
|
int dW = INT_ARG(11); // dilations width
|
||||||
|
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||||
|
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
|
||||||
|
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||||
|
"MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
|
||||||
|
input->rankOf());
|
||||||
|
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
|
||||||
|
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||||
|
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||||
|
|
||||||
|
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||||
|
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||||
|
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
|
||||||
|
"MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
|
||||||
|
expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||||
|
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
|
||||||
|
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
|
||||||
|
|
||||||
|
if (!isNCDHW) {
|
||||||
|
input = new NDArray(
|
||||||
|
input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
output = new NDArray(
|
||||||
|
output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
|
||||||
|
dW);
|
||||||
|
|
||||||
|
|
||||||
|
auto poolingMode = PoolingType::MAX_POOL;
|
||||||
|
auto extraParam0 = 1;
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||||
|
mkldnn::algorithm algorithm;
|
||||||
|
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||||
|
extraParam0, true,
|
||||||
|
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
|
||||||
|
algorithm,
|
||||||
|
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||||
|
&user_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
|
||||||
|
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||||
|
pool_dst_md, pool_strides, pool_kernel, pool_padding,
|
||||||
|
pool_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||||
|
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||||
|
|
||||||
|
auto pool_src_memory = user_src_memory;
|
||||||
|
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||||
|
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||||
|
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto pool_dst_memory = user_dst_memory;
|
||||||
|
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||||
|
{MKLDNN_ARG_DST, pool_dst_memory}});
|
||||||
|
|
||||||
|
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
|
||||||
|
if (!isNCDHW) {
|
||||||
|
delete input;
|
||||||
|
delete output;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(maxpool3dnew) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,185 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
PLATFORM_IMPL(maxpool3dnew_bp) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto gradO = INPUT_VARIABLE(
|
||||||
|
1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||||
|
|
||||||
|
const int kD = INT_ARG(0); // filter(kernel) depth
|
||||||
|
const int kH = INT_ARG(1); // filter(kernel) height
|
||||||
|
const int kW = INT_ARG(2); // filter(kernel) width
|
||||||
|
const int sD = INT_ARG(3); // strides depth
|
||||||
|
const int sH = INT_ARG(4); // strides height
|
||||||
|
const int sW = INT_ARG(5); // strides width
|
||||||
|
int pD = INT_ARG(6); // paddings depth
|
||||||
|
int pH = INT_ARG(7); // paddings height
|
||||||
|
int pW = INT_ARG(8); // paddings width
|
||||||
|
const int dD = INT_ARG(9); // dilations depth
|
||||||
|
const int dH = INT_ARG(10); // dilations height
|
||||||
|
const int dW = INT_ARG(11); // dilations width
|
||||||
|
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||||
|
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
|
||||||
|
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||||
|
"MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
|
||||||
|
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
|
||||||
|
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||||
|
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||||
|
|
||||||
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||||
|
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||||
|
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||||
|
{bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||||
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||||
|
"MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||||
|
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||||
|
"MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||||
|
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||||
|
|
||||||
|
if (!isNCDHW) {
|
||||||
|
input = new NDArray(input->permute(
|
||||||
|
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
gradI = new NDArray(gradI->permute(
|
||||||
|
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
gradO = new NDArray(gradO->permute(
|
||||||
|
{0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
|
||||||
|
dW);
|
||||||
|
|
||||||
|
|
||||||
|
auto poolingMode = PoolingType::MAX_POOL;
|
||||||
|
auto extraParam0 = 1;
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||||
|
mkldnn::algorithm algorithm;
|
||||||
|
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||||
|
extraParam0, true,
|
||||||
|
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
|
||||||
|
algorithm,
|
||||||
|
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||||
|
&user_diff_src_md, &user_dst_md,
|
||||||
|
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
|
||||||
|
// input is sometimes null, so we can't rely on pool_src_md being valid
|
||||||
|
if (input->buffer() == nullptr) {
|
||||||
|
pool_src_md = pool_diff_src_md;
|
||||||
|
user_src_md = user_diff_src_md;
|
||||||
|
}
|
||||||
|
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||||
|
|
||||||
|
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||||
|
|
||||||
|
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||||
|
auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||||
|
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
|
||||||
|
|
||||||
|
auto poolB_src_memory = userB_src_memory;
|
||||||
|
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||||
|
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto poolB_dst_memory = userB_dst_memory;
|
||||||
|
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||||
|
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||||
|
|
||||||
|
auto pool_src_memory = user_src_memory;
|
||||||
|
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||||
|
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||||
|
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||||
|
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine);
|
||||||
|
|
||||||
|
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||||
|
{MKLDNN_ARG_DST, pool_dst_memory},
|
||||||
|
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}});
|
||||||
|
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
|
||||||
|
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
|
||||||
|
|
||||||
|
|
||||||
|
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||||
|
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
if (!isNCDHW) {
|
||||||
|
delete input;
|
||||||
|
delete gradI;
|
||||||
|
delete gradO;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(maxpool3dnew_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,404 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <mkldnn_types.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace mkldnnUtils {
|
||||||
|
void getMKLDNNMemoryDescPool2d(
|
||||||
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||||
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||||
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||||
|
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||||
|
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
||||||
|
mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW };
|
||||||
|
mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW };
|
||||||
|
|
||||||
|
pool_strides = { sH, sW };
|
||||||
|
pool_kernel = { kH, kW };
|
||||||
|
pool_padding = { pH, pW };
|
||||||
|
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
||||||
|
(oW - 1) * sW - iW + kW - pW };
|
||||||
|
|
||||||
|
algorithm = poolingMode == 0 ? algorithm::pooling_max
|
||||||
|
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
|
||||||
|
: algorithm::pooling_avg_include_padding;
|
||||||
|
auto type = mkldnn::memory::data_type::f32;
|
||||||
|
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||||
|
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||||
|
|
||||||
|
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||||
|
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||||
|
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||||
|
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||||
|
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||||
|
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
||||||
|
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescPool3d(
|
||||||
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||||
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||||
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||||
|
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||||
|
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
||||||
|
mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
|
||||||
|
mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
|
||||||
|
|
||||||
|
pool_strides = { sD, sH, sW };
|
||||||
|
pool_kernel = { kD, kH, kW };
|
||||||
|
pool_padding = { pD, pH, pW };
|
||||||
|
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
||||||
|
(oH - 1) * sH - iH + kH - pH,
|
||||||
|
(oW - 1) * sW - iW + kW - pW };
|
||||||
|
|
||||||
|
algorithm = poolingMode == 0 ? algorithm::pooling_max
|
||||||
|
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
|
||||||
|
: algorithm::pooling_avg_include_padding;
|
||||||
|
auto type = mkldnn::memory::data_type::f32;
|
||||||
|
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
|
||||||
|
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nCdhw8c; // doesn't work with "any"
|
||||||
|
|
||||||
|
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||||
|
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||||
|
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||||
|
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||||
|
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||||
|
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||||
|
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
||||||
|
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescConv2d(
|
||||||
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
|
||||||
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
|
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||||
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
|
||||||
|
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
|
||||||
|
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
|
||||||
|
mkldnn::memory::dims conv_bias_tz = { oC };
|
||||||
|
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
||||||
|
|
||||||
|
conv_strides = { sH, sW };
|
||||||
|
conv_padding = { pH, pW };
|
||||||
|
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
||||||
|
(oW - 1) * sW - iW + kW - pW };
|
||||||
|
|
||||||
|
auto type = mkldnn::memory::data_type::f32;
|
||||||
|
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||||
|
auto formatw = mkldnn::memory::format_tag::hwio;
|
||||||
|
|
||||||
|
if (src != nullptr && conv_src_md != nullptr) {
|
||||||
|
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
||||||
|
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
||||||
|
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
||||||
|
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (weights != nullptr && conv_weights_md != nullptr) {
|
||||||
|
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
||||||
|
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio"
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
||||||
|
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
||||||
|
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio"
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bias != nullptr && conv_bias_md != nullptr) {
|
||||||
|
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst != nullptr && conv_dst_md != nullptr) {
|
||||||
|
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
|
||||||
|
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescConv3d(
|
||||||
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
||||||
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
|
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||||
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
|
||||||
|
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
|
||||||
|
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
|
||||||
|
mkldnn::memory::dims conv_bias_tz = { oC };
|
||||||
|
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
||||||
|
|
||||||
|
conv_strides = { sD, sH, sW };
|
||||||
|
conv_padding = { pD, pH, pW };
|
||||||
|
conv_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
||||||
|
(oH - 1) * sH - iH + kH - pH,
|
||||||
|
(oW - 1) * sW - iW + kW - pW };
|
||||||
|
|
||||||
|
auto type = mkldnn::memory::data_type::f32;
|
||||||
|
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
|
||||||
|
auto formatw = mkldnn::memory::format_tag::dhwio;
|
||||||
|
|
||||||
|
if (src != nullptr && conv_src_md != nullptr) {
|
||||||
|
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
||||||
|
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||||
|
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
||||||
|
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
||||||
|
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (weights != nullptr && conv_weights_md != nullptr) {
|
||||||
|
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
||||||
|
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio"
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
||||||
|
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
||||||
|
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio"
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bias != nullptr && conv_bias_md != nullptr) {
|
||||||
|
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst != nullptr && conv_dst_md != nullptr) {
|
||||||
|
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any);
|
||||||
|
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
|
||||||
|
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||||
|
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
|
||||||
|
const Nd4jLong* shape = src->getShapeInfo();
|
||||||
|
Nd4jLong rank = shape[0];
|
||||||
|
Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
||||||
|
Nd4jLong dim2 = axis >= 2 ? 1 : 2;
|
||||||
|
Nd4jLong dim3 = axis >= 3 ? 2 : 3;
|
||||||
|
mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||||
|
|
||||||
|
auto type = mkldnn::memory::data_type::f32;
|
||||||
|
auto format = mkldnn::memory::format_tag::nchw;
|
||||||
|
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||||
|
|
||||||
|
if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
|
||||||
|
*batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
||||||
|
user_src_md->data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
||||||
|
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
|
||||||
|
*batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
||||||
|
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
|
||||||
|
*batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
||||||
|
user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||||
|
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
|
||||||
|
const Nd4jLong* shape = src->getShapeInfo();
|
||||||
|
long rank = shape[0];
|
||||||
|
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
||||||
|
long dim2 = axis >= 2 ? 1 : 2;
|
||||||
|
long dim3 = axis >= 3 ? 2 : 3;
|
||||||
|
mkldnn::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||||
|
|
||||||
|
auto type = mkldnn::memory::data_type::f32;
|
||||||
|
auto format = axis == 1 ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||||
|
auto supposed_to_be_any_format = format; // doesn't work with "any"
|
||||||
|
|
||||||
|
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
|
||||||
|
*lrn_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
|
||||||
|
user_src_md->data.format_kind = mkldnn_blocked;
|
||||||
|
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
||||||
|
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
|
||||||
|
*lrn_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
|
||||||
|
user_diff_src_md->data.format_kind = mkldnn_blocked;
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
|
||||||
|
*lrn_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
|
||||||
|
user_dst_md->data.format_kind = mkldnn_blocked;
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mkldnn::engine& getEngine(void *ptr) {
|
||||||
|
auto eng = reinterpret_cast<mkldnn::engine*>(ptr);
|
||||||
|
return *eng;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,124 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author saudet
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef DEV_TESTS_MKLDNNUTILS_H
|
||||||
|
#define DEV_TESTS_MKLDNNUTILS_H
|
||||||
|
|
||||||
|
#include <NativeOps.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <mkldnn.hpp>
|
||||||
|
#include <MKLDNNStream.h>
|
||||||
|
#include <graph/Context.h>
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace nd4j{
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
/**
|
||||||
|
* Here we actually declare our platform helpers
|
||||||
|
*/
|
||||||
|
DECLARE_PLATFORM(conv2d);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(conv2d_bp);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(avgpool2d);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(avgpool2d_bp);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(maxpool2d);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(maxpool2d_bp);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(conv3dnew);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(conv3dnew_bp);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(maxpool3dnew);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(maxpool3dnew_bp);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(avgpool3dnew);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(avgpool3dnew_bp);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(lrn);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(batchnorm_new);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace mkldnnUtils {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility methods for MKLDNN
|
||||||
|
*/
|
||||||
|
void getMKLDNNMemoryDescConv2d(
|
||||||
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
|
||||||
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
|
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||||
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescConv3d(
|
||||||
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
||||||
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
|
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||||
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescPool2d(
|
||||||
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||||
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||||
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||||
|
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||||
|
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescPool3d(
|
||||||
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||||
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||||
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||||
|
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||||
|
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||||
|
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis);
|
||||||
|
|
||||||
|
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||||
|
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis);
|
||||||
|
|
||||||
|
mkldnn::engine& getEngine(void *ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#endif //DEV_TESTS_MKLDNNUTILS_H
|
|
@ -0,0 +1,45 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SD_PLATFORM_BOILERPLATE_H
|
||||||
|
#define SD_PLATFORM_BOILERPLATE_H
|
||||||
|
|
||||||
|
|
||||||
|
#define DECLARE_PLATFORM(NAME) class ND4J_EXPORT PLATFORM_##NAME : public PlatformHelper {\
|
||||||
|
public: \
|
||||||
|
PLATFORM_##NAME() : PlatformHelper(#NAME) { } \
|
||||||
|
bool isUsable(graph::Context &context) override; \
|
||||||
|
Nd4jStatus invokeHelper(graph::Context &context) override; \
|
||||||
|
};
|
||||||
|
|
||||||
|
#define PLATFORM_IMPL(NAME) struct ND4J_EXPORT __registratorPlatformHelper_##NAME { \
|
||||||
|
__registratorPlatformHelper_##NAME() { \
|
||||||
|
auto helper = new PLATFORM_##NAME(); \
|
||||||
|
OpRegistrator::getInstance()->registerHelper(helper); \
|
||||||
|
} \
|
||||||
|
}; \
|
||||||
|
static __registratorPlatformHelper_##NAME platformHelper_##NAME; \
|
||||||
|
Nd4jStatus PLATFORM_##NAME::invokeHelper(nd4j::graph::Context &block)
|
||||||
|
|
||||||
|
|
||||||
|
#define PLATFORM_CHECK(NAME) bool PLATFORM_##NAME::isUsable(graph::Context &block)
|
||||||
|
|
||||||
|
|
||||||
|
#endif //SD_PLATFORM_BOILERPLATE_H
|
|
@ -21,7 +21,7 @@
|
||||||
#include <iosfwd>
|
#include <iosfwd>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <pointercast.h>
|
#include <pointercast.h>
|
||||||
#if defined(__INTEL_COMPILER) || defined(__F16C__)
|
#if defined(__INTEL_COMPILER) || defined(SD_F16C)
|
||||||
#include <immintrin.h>
|
#include <immintrin.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -122,7 +122,7 @@ static local_def unsigned short hneg(unsigned short h) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#if defined(__INTEL_COMPILER) || defined(__F16C__)
|
#if defined(__INTEL_COMPILER) || defined(SD_F16C)
|
||||||
//_Pragma("omp declare simd") inline
|
//_Pragma("omp declare simd") inline
|
||||||
local_def float cpu_ihalf2float(ihalf h) {
|
local_def float cpu_ihalf2float(ihalf h) {
|
||||||
return _cvtsh_ss(h.getX());
|
return _cvtsh_ss(h.getX());
|
||||||
|
@ -157,7 +157,7 @@ local_def float cpu_ihalf2float(ihalf h) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__INTEL_COMPILER) || defined(__F16C__)
|
#if defined(__INTEL_COMPILER) || defined(SD_F16C)
|
||||||
//_Pragma("omp declare simd") inline
|
//_Pragma("omp declare simd") inline
|
||||||
local_def ihalf cpu_float2ihalf_rn(float f) {
|
local_def ihalf cpu_float2ihalf_rn(float f) {
|
||||||
ihalf ret;
|
ihalf ret;
|
||||||
|
|
|
@ -74,6 +74,7 @@
|
||||||
<libnd4j.compute></libnd4j.compute>
|
<libnd4j.compute></libnd4j.compute>
|
||||||
<libnd4j.classifier>${libnd4j.platform}</libnd4j.classifier>
|
<libnd4j.classifier>${libnd4j.platform}</libnd4j.classifier>
|
||||||
<libnd4j.buildthreads></libnd4j.buildthreads>
|
<libnd4j.buildthreads></libnd4j.buildthreads>
|
||||||
|
<libnd4j.helper></libnd4j.helper>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<build>
|
<build>
|
||||||
|
@ -175,6 +176,8 @@
|
||||||
<argument>${libnd4j.tests}</argument>
|
<argument>${libnd4j.tests}</argument>
|
||||||
<argument>-j</argument>
|
<argument>-j</argument>
|
||||||
<argument>${libnd4j.buildthreads}</argument>
|
<argument>${libnd4j.buildthreads}</argument>
|
||||||
|
<argument>-h</argument>
|
||||||
|
<argument>${libnd4j.helper}</argument>
|
||||||
</buildCommand>
|
</buildCommand>
|
||||||
<workingDirectory>${project.basedir}</workingDirectory>
|
<workingDirectory>${project.basedir}</workingDirectory>
|
||||||
</configuration>
|
</configuration>
|
||||||
|
@ -391,5 +394,30 @@
|
||||||
</plugins>
|
</plugins>
|
||||||
</build>
|
</build>
|
||||||
</profile>
|
</profile>
|
||||||
|
<!-- Profiles to set the default libnd4j.helper property, example: mkdnn -->
|
||||||
|
<profile>
|
||||||
|
<id>libnd4j-helper-avx2</id>
|
||||||
|
<activation>
|
||||||
|
<property>
|
||||||
|
<name>libnd4j.extension</name>
|
||||||
|
<value>avx2</value>
|
||||||
|
</property>
|
||||||
|
</activation>
|
||||||
|
<properties>
|
||||||
|
<libnd4j.helper>mkldnn</libnd4j.helper>
|
||||||
|
</properties>
|
||||||
|
</profile>
|
||||||
|
<profile>
|
||||||
|
<id>libnd4j-helper-avx512</id>
|
||||||
|
<activation>
|
||||||
|
<property>
|
||||||
|
<name>libnd4j.extension</name>
|
||||||
|
<value>avx512</value>
|
||||||
|
</property>
|
||||||
|
</activation>
|
||||||
|
<properties>
|
||||||
|
<libnd4j.helper>mkldnn</libnd4j.helper>
|
||||||
|
</properties>
|
||||||
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -131,7 +131,7 @@ endforeach(TMP_PATH)
|
||||||
|
|
||||||
if (CPU_BLAS)
|
if (CPU_BLAS)
|
||||||
add_executable(runtests ${TEST_SOURCES})
|
add_executable(runtests ${TEST_SOURCES})
|
||||||
target_link_libraries(runtests ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} gtest gtest_main)
|
target_link_libraries(runtests ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main)
|
||||||
elseif(CUDA_BLAS)
|
elseif(CUDA_BLAS)
|
||||||
CUDA_ADD_EXECUTABLE(runtests ${TEST_SOURCES})
|
CUDA_ADD_EXECUTABLE(runtests ${TEST_SOURCES})
|
||||||
target_link_libraries(runtests ${LIBND4J_NAME} ${CUDA_LIBRARIES} gtest gtest_main)
|
target_link_libraries(runtests ${LIBND4J_NAME} ${CUDA_LIBRARIES} gtest gtest_main)
|
||||||
|
|
|
@ -32,6 +32,10 @@
|
||||||
#include <ops/declarable/helpers/col2im.h>
|
#include <ops/declarable/helpers/col2im.h>
|
||||||
#include <PointersManager.h>
|
#include <PointersManager.h>
|
||||||
|
|
||||||
|
#ifdef HAVE_MKLDNN
|
||||||
|
#include <ops/declarable/platform/mkldnn/mkldnnUtils.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
using namespace nd4j;
|
using namespace nd4j;
|
||||||
using namespace nd4j::graph;
|
using namespace nd4j::graph;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <initializer_list>
|
||||||
|
|
||||||
|
#ifdef HAVE_MKLDNN
|
||||||
|
|
||||||
|
#include <ops/declarable/platform/mkldnn/mkldnnUtils.h>
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
class MklDnnTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
static void printer(std::initializer_list<nd4j::ops::platforms::PlatformHelper*> helpers) {
|
||||||
|
|
||||||
|
for (auto v:helpers) {
|
||||||
|
nd4j_printf("Initialized [%s]\n", v->name().c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(MklDnnTests, helpers_includer) {
|
||||||
|
// we need this block, to make sure all helpers are still available within binary, and not optimized out by linker
|
||||||
|
#ifdef HAVE_MKLDNN
|
||||||
|
nd4j::ops::platforms::PLATFORM_conv2d conv2d;
|
||||||
|
nd4j::ops::platforms::PLATFORM_conv2d_bp conv2d_bp;
|
||||||
|
|
||||||
|
nd4j::ops::platforms::PLATFORM_conv2d conv3d;
|
||||||
|
nd4j::ops::platforms::PLATFORM_conv2d_bp conv3d_bp;
|
||||||
|
|
||||||
|
nd4j::ops::platforms::PLATFORM_avgpool2d avgpool2d;
|
||||||
|
nd4j::ops::platforms::PLATFORM_avgpool2d_bp avgpool2d_bp;
|
||||||
|
|
||||||
|
nd4j::ops::platforms::PLATFORM_maxpool2d maxpool2d;
|
||||||
|
nd4j::ops::platforms::PLATFORM_maxpool2d_bp maxpool2d_bp;
|
||||||
|
|
||||||
|
nd4j::ops::platforms::PLATFORM_avgpool3dnew avgpool3d;
|
||||||
|
nd4j::ops::platforms::PLATFORM_avgpool3dnew_bp avgpool3d_bp;
|
||||||
|
|
||||||
|
nd4j::ops::platforms::PLATFORM_maxpool3dnew maxpool3d;
|
||||||
|
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp maxpool3d_bp;
|
||||||
|
|
||||||
|
nd4j::ops::platforms::PLATFORM_lrn lrn;
|
||||||
|
nd4j::ops::platforms::PLATFORM_batchnorm_new batchnorm;
|
||||||
|
|
||||||
|
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm});
|
||||||
|
#endif
|
||||||
|
}
|
|
@ -21,7 +21,7 @@ endif()
|
||||||
# OPTIONAL MKL-DNN
|
# OPTIONAL MKL-DNN
|
||||||
if ("${BUILD_MKLDNN}")
|
if ("${BUILD_MKLDNN}")
|
||||||
# Download and unpack mkl-dnn at configure time
|
# Download and unpack mkl-dnn at configure time
|
||||||
configure_file(./CMakeLists.txt.in mkldnn-download/CMakeLists.txt)
|
configure_file(../../CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt)
|
||||||
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
|
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
|
||||||
RESULT_VARIABLE result
|
RESULT_VARIABLE result
|
||||||
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download )
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download )
|
||||||
|
@ -40,7 +40,7 @@ if ("${BUILD_MKLDNN}")
|
||||||
EXCLUDE_FROM_ALL)
|
EXCLUDE_FROM_ALL)
|
||||||
set(mkldnn_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src)
|
set(mkldnn_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src)
|
||||||
set(HAVE_MKLDNN 1)
|
set(HAVE_MKLDNN 1)
|
||||||
add_definitions(-DHAVE_MKLDNN=true)
|
add_definitions("-DHAVE_MKLDNN")
|
||||||
include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_SOURCE_DIR}/external/mklml_lnx_2019.0.3.20190220/include ${mkldnn_SOURCE_DIR})
|
include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_SOURCE_DIR}/external/mklml_lnx_2019.0.3.20190220/include ${mkldnn_SOURCE_DIR})
|
||||||
set(MKLDNN mkldnn)
|
set(MKLDNN mkldnn)
|
||||||
endif()
|
endif()
|
||||||
|
@ -131,7 +131,7 @@ else()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (${F16C})
|
if (${F16C})
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c -D__F16C__=true")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c -DSD_F16C=true")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -177,6 +177,7 @@ if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}
|
||||||
message(FATAL_ERROR "You need at least GCC 4.9")
|
message(FATAL_ERROR "You need at least GCC 4.9")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
message("Looking for OpenMP")
|
||||||
find_package(OpenMP)
|
find_package(OpenMP)
|
||||||
if (OPENMP_FOUND)
|
if (OPENMP_FOUND)
|
||||||
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||||
|
@ -185,10 +186,11 @@ else()
|
||||||
message("OPENMP NOT FOUND")
|
message("OPENMP NOT FOUND")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if ("${OPENBLAS}" OR CMAKE_BUILD_TYPE STREQUAL "Release")
|
if ("${OPENBLAS}" OR CMAKE_BUILD_TYPE STREQUAL "Release" OR "${BUILD_MKLDNN}")
|
||||||
find_package(BLAS)
|
message("Looking for BLAS")
|
||||||
|
find_package(BLAS REQUIRED)
|
||||||
if (BLAS_FOUND)
|
if (BLAS_FOUND)
|
||||||
message("Found external BLAS implementation...")
|
message("Found external BLAS library: ${BLAS_LIBRARIES}")
|
||||||
add_definitions(-D__EXTERNAL_BLAS__=true)
|
add_definitions(-D__EXTERNAL_BLAS__=true)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
@ -201,13 +203,18 @@ file(GLOB_RECURSE ARRAY_SOURCES false ../../include/array/*.cpp ../../include/ar
|
||||||
file(GLOB_RECURSE MEMORY_SOURCES false ../../include/memory/*.cpp ../../include/memory/*.h)
|
file(GLOB_RECURSE MEMORY_SOURCES false ../../include/memory/*.cpp ../../include/memory/*.h)
|
||||||
file(GLOB_RECURSE GRAPH_SOURCES false ../../include/graph/*.cpp ../../include/graph/*.h)
|
file(GLOB_RECURSE GRAPH_SOURCES false ../../include/graph/*.cpp ../../include/graph/*.h)
|
||||||
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../../include/ops/declarable/generic/*.cpp)
|
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../../include/ops/declarable/generic/*.cpp)
|
||||||
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../../include/ops/declarable/helpers/cpu/*.cpp ../../include/ops/declarable/helpers/impl/*.cpp)
|
file(GLOB_RECURSE CUSTOMOPS_GENERIC_SOURCES false ../../include/ops/declarable/helpers/cpu/*.cpp ../../include/ops/declarable/helpers/impl/*.cpp)
|
||||||
file(GLOB_RECURSE OPS_SOURCES false ../../include/ops/impl/*.cpp ../../include/ops/declarable/impl/*.cpp ../../include/ops/*.h)
|
file(GLOB_RECURSE OPS_SOURCES false ../../include/ops/impl/*.cpp ../../include/ops/declarable/impl/*.cpp ../../include/ops/*.h)
|
||||||
file(GLOB_RECURSE INDEXING_SOURCES false ../../include/indexing/*.cpp ../../include/indexing/*.h)
|
file(GLOB_RECURSE INDEXING_SOURCES false ../../include/indexing/*.cpp ../../include/indexing/*.h)
|
||||||
file(GLOB_RECURSE HELPERS_SOURCES false ../../include/helpers/*.cpp ../../include/helpers/*.h)
|
file(GLOB_RECURSE HELPERS_SOURCES false ../../include/helpers/*.cpp)
|
||||||
file(GLOB_RECURSE LOOPS_SOURCES false ../../include/loops/*.cpp ../../include/loops/*.h)
|
file(GLOB_RECURSE LOOPS_SOURCES false ../../include/loops/*.cpp ../../include/loops/*.h)
|
||||||
|
|
||||||
message("CPU BLAS")
|
# optionally build mkldnn
|
||||||
|
if ("${BUILD_MKLDNN}")
|
||||||
|
file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
message("CPU backend")
|
||||||
add_definitions(-D__CPUBLAS__=true)
|
add_definitions(-D__CPUBLAS__=true)
|
||||||
|
|
||||||
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE))
|
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE))
|
||||||
|
@ -216,8 +223,37 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE))
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
# this function strips path from file name, basically making up short file name, i.e. file.cpp
|
||||||
|
function(SHORTNAME LONG_NAME OUTPUT)
|
||||||
|
SET(_TMP_STR "")
|
||||||
|
string (REGEX REPLACE ".*/" "" _TMP_STR "${LONG_NAME}")
|
||||||
|
set (${OUTPUT} "${_TMP_STR}" PARENT_SCOPE)
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
# now we ned to join two lists
|
||||||
|
# first of all we'll build truncated list of files in platform sources
|
||||||
|
# and list of priority implementations from platform helpers
|
||||||
|
#set(CUSTOMOPS_HELPERS_SOURCES "")
|
||||||
|
#set(SHORT_NAMES "")
|
||||||
|
#foreach(LONG_NAME ${CUSTOMOPS_PLATFORM_SOURCES})
|
||||||
|
# SHORTNAME("${LONG_NAME}" "SHORT_NAME")
|
||||||
|
# set(CUSTOMOPS_HELPERS_SOURCES ${CUSTOMOPS_HELPERS_SOURCES} ${LONG_NAME})
|
||||||
|
# set(SHORT_NAMES ${SHORT_NAMES} ${SHORT_NAME})
|
||||||
|
#endforeach()
|
||||||
|
|
||||||
|
# now we're going to filter generic helpers, to exclude platform implementations
|
||||||
|
#foreach(LONG_NAME ${CUSTOMOPS_GENERIC_SOURCES})
|
||||||
|
# SHORTNAME("${LONG_NAME}" "SHORT_NAME")
|
||||||
|
|
||||||
|
# and now we add this op ONLY if it wasn't announced in platform helpers
|
||||||
|
# string(FIND "${SHORT_NAMES}" "${SHORT_NAME}" "LOC")
|
||||||
|
# if (${LOC} EQUAL -1)
|
||||||
|
# set(CUSTOMOPS_HELPERS_SOURCES ${CUSTOMOPS_HELPERS_SOURCES} ${LONG_NAME})
|
||||||
|
# endif()
|
||||||
|
#endforeach()
|
||||||
|
|
||||||
|
|
||||||
file(GLOB_RECURSE TEST_SOURCES false ../layers_tests/*.cpp ../layers_tests/*.h)
|
file(GLOB_RECURSE TEST_SOURCES false ../layers_tests/*.cpp ../layers_tests/*.h)
|
||||||
# file(GLOB_RECURSE TEST_SOURCES false ../layers_tests/DeclarableOpsTests6.cpp ../layers_tests/*.h)
|
|
||||||
|
|
||||||
|
|
||||||
# Filter out any source files from */CMakeFiles/* paths. these tend to cause problems such a multiple main definitions.
|
# Filter out any source files from */CMakeFiles/* paths. these tend to cause problems such a multiple main definitions.
|
||||||
|
@ -234,7 +270,7 @@ add_executable(runtests ${LOOPS_SOURCES} ../../blas/cpu/NativeOps.cpp ../../blas
|
||||||
../../blas/cpu/NativeOpExecutioner.cpp ../../blas/cpu/NDArray.cpp ../../blas/cpu/NDArrayFactory.cpp
|
../../blas/cpu/NativeOpExecutioner.cpp ../../blas/cpu/NDArray.cpp ../../blas/cpu/NDArrayFactory.cpp
|
||||||
../../include/cnpy/cnpy.cpp ../../include/nd4jmemset.h ../../include/nd4jmalloc.h
|
../../include/cnpy/cnpy.cpp ../../include/nd4jmemset.h ../../include/nd4jmalloc.h
|
||||||
../../blas/Environment.cpp ../../blas/Environment.h ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
../../blas/Environment.cpp ../../blas/Environment.h ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_HELPERS_SOURCES}
|
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
||||||
${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES})
|
${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES})
|
||||||
|
|
||||||
target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES})
|
target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES})
|
||||||
|
|
|
@ -1153,4 +1153,10 @@ public interface NativeOps {
|
||||||
String lastErrorMessage();
|
String lastErrorMessage();
|
||||||
|
|
||||||
boolean isBlasVersionMatches(int major, int minor, int build);
|
boolean isBlasVersionMatches(int major, int minor, int build);
|
||||||
|
|
||||||
|
int binaryLevel();
|
||||||
|
int optimalLevel();
|
||||||
|
|
||||||
|
boolean isMinimalRequirementsMet();
|
||||||
|
boolean isOptimalRequirementsMet();
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,6 +68,7 @@ import org.bytedeco.javacpp.tools.InfoMapper;
|
||||||
//"op_boilerplate.h",
|
//"op_boilerplate.h",
|
||||||
"ops/InputType.h",
|
"ops/InputType.h",
|
||||||
"ops/declarable/OpDescriptor.h",
|
"ops/declarable/OpDescriptor.h",
|
||||||
|
"ops/declarable/PlatformHelper.h",
|
||||||
"ops/declarable/BroadcastableOp.h",
|
"ops/declarable/BroadcastableOp.h",
|
||||||
"helpers/OpArgsHolder.h",
|
"helpers/OpArgsHolder.h",
|
||||||
"ops/declarable/DeclarableOp.h",
|
"ops/declarable/DeclarableOp.h",
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue