Platform helpers (#8216)

* platform helpers draft

Signed-off-by: raver119 <raver119@gmail.com>

* typo

Signed-off-by: raver119 <raver119@gmail.com>

* disable platform cmake

Signed-off-by: raver119 <raver119@gmail.com>

* another draft

Signed-off-by: raver119 <raver119@gmail.com>

* mkldnn convolution refactored

Signed-off-by: raver119 <raver119@gmail.com>

* minor tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* one more safety check

Signed-off-by: raver119 <raver119@gmail.com>

* prototype works

Signed-off-by: raver119 <raver119@gmail.com>

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* force static library mode for mkldnn

Signed-off-by: raver119 <raver119@gmail.com>

* - ismax fix
- experimental arg fix
- don't enforce openblas on Apple hardware

Signed-off-by: raver119 <raver119@gmail.com>

* bunch of small fixes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* declare concurrent

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* - MKLDNN version upgrade to 1.0.2
- avgpool2d/maxpool2d APIs update

Signed-off-by: raver119 <raver119@gmail.com>

* - avgpool2d_bp/maxpool2d_bp APIs update

Signed-off-by: raver119 <raver119@gmail.com>

* - conv2d/batchnorm APIs update

Signed-off-by: raver119 <raver119@gmail.com>

* - lrn/conv2d_bp/conv3d/conv3d_bp APIs update

Signed-off-by: raver119 <raver119@gmail.com>

* all ops converted to MKLDNN 1.x

Signed-off-by: raver119 <raver119@gmail.com>

* bunch of tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* namespace for platform helpers

Signed-off-by: raver119 <raver119@gmail.com>

* make sure platform helpers aren't optimized out

Signed-off-by: raver119 <raver119@gmail.com>

* build cpu_features on x86 systems

Signed-off-by: raver119 <raver119@gmail.com>

* build cpu_features on x86 systems

Signed-off-by: raver119 <raver119@gmail.com>

* more of cpu_features

Signed-off-by: raver119 <raver119@gmail.com>

* - mkldnn removed from java
- cpu_features checks in CpuNDArrayFactory

Signed-off-by: raver119 <raver119@gmail.com>

* F16C definition renamed

Signed-off-by: raver119 <raver119@gmail.com>

* some mkldnn rearrangements

Signed-off-by: raver119 <raver119@gmail.com>

* check supported instructions before doing anything

Signed-off-by: raver119 <raver119@gmail.com>

* typo

Signed-off-by: raver119 <raver119@gmail.com>

* missed impl

Signed-off-by: raver119 <raver119@gmail.com>

* BUILD_PIC option

Signed-off-by: raver119 <raver119@gmail.com>

* conv2d fix

Signed-off-by: raver119 <raver119@gmail.com>

* avgpool3d fix

Signed-off-by: raver119 <raver119@gmail.com>

* avgpool3d_bp fix

Signed-off-by: raver119 <raver119@gmail.com>

* avgpool2d_bp leak fix

Signed-off-by: raver119 <raver119@gmail.com>

* avgpool3d_bp leak fix

Signed-off-by: raver119 <raver119@gmail.com>

* maxpool bp leaks fixed

Signed-off-by: raver119 <raver119@gmail.com>

* printf removed

Signed-off-by: raver119 <raver119@gmail.com>

* batchnorm fix

Signed-off-by: raver119 <raver119@gmail.com>

* AVX warning/error polishing

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* remove previous MKL-DNN support layer

Signed-off-by: raver119 <raver119@gmail.com>

* avx2 tweak

Signed-off-by: raver119 <raver119@gmail.com>

* allow static for apple

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* exclude mkldnn in one more place

Signed-off-by: raver119 <raver119@gmail.com>

* exclude mkldnn in one more place

Signed-off-by: raver119 <raver119@gmail.com>

* restore OPENBLAS_PATH use

Signed-off-by: raver119 <raver119@gmail.com>

* add runtime check for avx/avx2 support

Signed-off-by: raver119 <raver119@gmail.com>

* convolution_auto

Signed-off-by: raver119 <raver119@gmail.com>

* Add logic for helper argument

* minor test fix

Signed-off-by: raver119 <raver119@gmail.com>

* few tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* few tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* skip OpTracker props for non-x86 builds

Signed-off-by: raver119 <raver119@gmail.com>

* linux arm isn't x86 :)

Signed-off-by: raver119 <raver119@gmail.com>

* avx-512

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA presets fix

Signed-off-by: raver119 <raver119@gmail.com>

* BUILD_PIC

Signed-off-by: raver119 <raver119@gmail.com>

* prefetchw for avx2

Signed-off-by: raver119 <raver119@gmail.com>

* BUILD_PIC again

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-11 21:50:28 +03:00 committed by GitHub
parent ffae024cda
commit 98e2814879
105 changed files with 3938 additions and 1494 deletions

View File

@@ -8,12 +8,19 @@ set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF)
 option(BUILD_TESTS "Build tests" OFF)
 
+set(X86_BUILD false)
+
+if (NOT IOS_BUILD AND NOT ANDROID_BUILD AND NOT ${ARCH} MATCHES "power*" AND NOT ${ARCH} MATCHES "arm*")
+    set(X86_BUILD true)
+endif()
+
 # -fsanitize=address
 # -fsanitize=leak
 if (APPLE)
-    set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D__APPLE_OS__=true -D_RELEASE=true")
+    set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
     set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D__APPLE_OS__=true")
 elseif(WIN32)
+    set(X86_BUILD true)
     if (NOT CUDA_BLAS)
         set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D_RELEASE=true")
         set(CMAKE_CXX_FLAGS_DEBUG " -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
@@ -32,6 +39,7 @@ endif()
 if(NATIVE)
     IF(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
+        set(X86_BUILD false)
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=native")
     ELSE()
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
@@ -39,21 +47,101 @@ if(NATIVE)
 endif()
 
 if(NOT CUDA_BLAS)
-    if (NOT "${MKLDNN_PATH}" STREQUAL "")
-        set(HAVE_MKLDNN 1)
-        include_directories(${MKLDNN_PATH}/include/)
-        link_directories(${MKLDNN_PATH} ${MKLDNN_PATH}/lib/)
-        IF(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
-            set(MKLDNN_LIBRARIES mkldnn mklml_intel)
-        else()
-            set(MKLDNN_LIBRARIES mkldnn mklml)
-        endif()
-    elseif (NOT "${OPENBLAS_PATH}" STREQUAL "")
+    # we need this definition to avoid global memory use within mkldnn
+    add_definitions(-DMKLDNN_ENABLE_CONCURRENT_EXEC=true)
+
+    # there's a chance we have no BLAS provided externally
+    if ("${OPENBLAS_PATH}" STREQUAL "")
+        # we don't want static OpenBLAS on Apple
+        set(BLA_STATIC ON)
+        if (NOT APPLE)
+            set(BLA_VENDOR "OpenBLAS")
+        endif()
+
+        # look around for a system BLAS instead
+        find_package(BLAS REQUIRED)
+        if (BLAS_FOUND)
+            message("Original library: ${BLAS_LIBRARIES}")
+            # workaround for cmake being unable to find the static blas library
+            SET(_TMP_B "")
+            if (APPLE)
+                string(REGEX REPLACE "\\.dylib$" ".lib" _TMP_B "${BLAS_LIBRARIES}")
+            elseif (WIN32)
+                string(REGEX REPLACE "\\.dll" ".lib" _TMP_B "${BLAS_LIBRARIES}")
+            else()
+                string(REGEX REPLACE "\\.so$" ".a" _TMP_B "${BLAS_LIBRARIES}")
+            endif()
+            set(BLAS_LIBRARIES "${_TMP_B}")
+            message("Found external BLAS implementation: ${BLAS_LIBRARIES}")
+            add_definitions(-D__EXTERNAL_BLAS__=true)
+        endif()
+    else()
+        # if we have an externally provided OPENBLAS_PATH - let's use it
         set(HAVE_OPENBLAS 1)
         include_directories(${OPENBLAS_PATH}/include/)
         link_directories(${OPENBLAS_PATH} ${OPENBLAS_PATH}/lib/)
         set(OPENBLAS_LIBRARIES openblas)
     endif()
+
+    # building cpu_features
+    if (X86_BUILD)
+        add_definitions(-DCPU_FEATURES=true)
+        set(BUILD_PIC "ON" CACHE STRING "Hack to enforce fPIC mode" FORCE)
+        configure_file(./CMakeLists.txt.cpu_features.in cpu_features-download/CMakeLists.txt)
+        message("CMAKE_COMMAND: ${CMAKE_COMMAND}")
+        execute_process(COMMAND ${CMAKE_COMMAND} -DBUILD_PIC=ON -G "${CMAKE_GENERATOR}" .
+                RESULT_VARIABLE result
+                WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/cpu_features-download)
+        if(result)
+            message(FATAL_ERROR "CMake step for cpu_features failed: ${result}")
+        endif()
+        execute_process(COMMAND ${CMAKE_COMMAND} --build .
+                RESULT_VARIABLE result
+                WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/cpu_features-download)
+        if(result)
+            message(FATAL_ERROR "Build step for cpu_features failed: ${result}")
+        endif()
+        add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/cpu_features-src
+                ${CMAKE_CURRENT_BINARY_DIR}/cpu_features-build
+                EXCLUDE_FROM_ALL)
+        set(CPUF_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cpu_features-src)
+        include_directories(${CPUF_SOURCE_DIR}/include)
+        set(CPU_FEATURES cpu_features)
+    endif()
+
+    # new mkl-dnn entry
+    if (${HELPERS_mkldnn})
+        message("Going to pull & build mkldnn")
+        set(HAVE_MKLDNN 1)
+        set(MKLDNN_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE)
+        configure_file(./CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt)
+        execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
+                RESULT_VARIABLE result
+                WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download)
+        if(result)
+            message(FATAL_ERROR "CMake step for mkldnn failed: ${result}")
+        endif()
+        execute_process(COMMAND ${CMAKE_COMMAND} --build .
+                RESULT_VARIABLE result
+                WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download)
+        if(result)
+            message(FATAL_ERROR "Build step for mkldnn failed: ${result}")
+        endif()
+        add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src
+                ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build
+                EXCLUDE_FROM_ALL)
+        set(mkldnn_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build)
+        set(mkldnn_EXT_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src)
+        set(MKLDNN_PATH "${mkldnn_SOURCE_DIR}")
+        include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_EXT_DIR}/include ${mkldnn_SOURCE_DIR})
+        set(MKLDNN mkldnn)
+    endif()
 endif()
 
 # Download and unpack flatbuffers at configure time

View File

@@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 2.8.2)

project(cpu_features-download NONE)

include(ExternalProject)
ExternalProject_Add(cpu_features
        GIT_REPOSITORY https://github.com/google/cpu_features.git
        GIT_TAG v0.4.1
        SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu_features-src"
        BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu_features-build"
        CONFIGURE_COMMAND ""
        CMAKE_ARGS "-DBUILD_PIC=ON"
        BUILD_COMMAND ""
        INSTALL_COMMAND ""
        TEST_COMMAND ""
)

View File

@@ -5,11 +5,11 @@ project(mkldnn-download NONE)
 include(ExternalProject)
 ExternalProject_Add(mkldnn
         GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
-        GIT_TAG v0.18.1
+        GIT_TAG v1.0.2
         SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
         BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
-        CONFIGURE_COMMAND "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src/scripts/prepare_mkl.sh"
-        CMAKE_ARGS -DMKLDNN_USE_MKL=ML -G \"Unix Makefiles\" -DMKLDNN_LIBRARY_TYPE=STATIC
+        CONFIGURE_COMMAND ""
+        CMAKE_ARGS -DMKLDNN_USE_MKL=ML -DMKLDNN_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\"
         BUILD_COMMAND ""
         INSTALL_COMMAND ""
         TEST_COMMAND ""

View File

@@ -78,14 +78,24 @@ IF(${ARCH} MATCHES "arm*")
 ELSEIF(${ARCH} MATCHES "power*")
     set(ARCH_TUNE "-mcpu=${ARCH} -mtune=${ARCH} -D__POWER")
 ELSEIF(${EXTENSION} MATCHES "avx2")
-    set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH} -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -D__F16C__=true")
+    message("Building AVX2 binary...")
+    set(ARCH_TUNE "-mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mprefetchwt1 -DSD_F16C=true -DF_AVX2=true")
 ELSE()
     if ("${ARCH}" STREQUAL "x86-64")
+        message("Building x86_64 binary...")
         set(ARCH_TYPE "generic")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DF_X64=true")
     else()
         set(ARCH_TYPE "${ARCH}")
     endif()
+
+    IF(${EXTENSION} MATCHES "avx512")
+        message("Building AVX512 binary...")
+        # we set the flags here so we can use hardware f16 conversion and mark that cpu features should be tracked
+        message("Current CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves -DSD_F16C=true -DF_AVX512=true")
+    endif()
+
     set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH_TYPE}")
 ENDIF()
@@ -299,19 +309,31 @@ elseif(CPU_BLAS)
     file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
     file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
     file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
-    file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp ../include/ops/declarable/helpers/impl/*.cpp)
+    file(GLOB_RECURSE CUSTOMOPS_GENERIC_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp ../include/ops/declarable/helpers/impl/*.cpp)
     file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
     file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
     file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h)
     file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
 
+    # if MKLDNN is enabled - we're building mkldnn-powered helpers
+    if (HAVE_MKLDNN)
+        file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h)
+    endif()
+
+    if (X86_BUILD)
+        # we disable platform optimizations for certain files
+        set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
+        set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
+    endif()
+
     message("CPU BLAS")
     add_definitions(-D__CPUBLAS__=true)
     add_library(nd4jobj OBJECT cpu/NativeOps.cpp cpu/GraphExecutioner.cpp
             cpu/NativeOpExecutioner.cpp cpu/NDArray.cpp cpu/NDArrayFactory.cpp
             ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
             Environment.cpp Environment.h ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
-            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_HELPERS_SOURCES}
+            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
             ${OPS_SOURCES} ${PERF_SOURCES})
 
     if(IOS)
         add_library(${LIBND4J_NAME} STATIC $<TARGET_OBJECTS:nd4jobj>)
@@ -320,12 +342,13 @@ elseif(CPU_BLAS)
         add_library(${LIBND4J_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
     endif()
 
-    target_link_libraries(${LIBND4J_NAME} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES})
+    # we're including ${MKLDNN} here for builds from source. in future that'll replace ${MKLDNN_LIBRARIES}. same applies to BLAS
+    target_link_libraries(${LIBND4J_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
 
     if ("${LIBND4J_ALL_OPS}" AND "${LIBND4J_BUILD_MINIFIER}")
         message(STATUS "Building minifier...")
         add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
-        target_link_libraries(minifier ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES})
+        target_link_libraries(minifier ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES})
     endif()
 
     if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)

View File

@@ -1760,6 +1760,13 @@ ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc);
 ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
 ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
 
+ND4J_EXPORT int binaryLevel();
+ND4J_EXPORT int optimalLevel();
+
+ND4J_EXPORT bool isMinimalRequirementsMet();
+ND4J_EXPORT bool isOptimalRequirementsMet();
+
 }
 
 #endif //NATIVEOPERATIONS_NATIVEOPS_H

View File

@@ -76,6 +76,10 @@ bool experimentalSupport = false;
 #include <performance/benchmarking/FullBenchmarkSuit.h>
 #include <performance/benchmarking/LightBenchmarkSuit.h>
 
+#ifdef CPU_FEATURES
+#include <cpuinfo_x86.h>
+#endif
+
 using namespace nd4j;
 
 void setElementThreshold(int num) {
@@ -3167,6 +3171,75 @@ const char* lastErrorMessage() {
     return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage();
 }
 
+int binaryLevel() {
+#ifdef CPU_FEATURES
+    #if defined(F_X64)
+    return 1;
+    #elif defined(F_AVX2)
+    return 2;
+    #elif defined(F_AVX512)
+    return 3;
+    #else
+    return 0;
+    #endif
+#else
+    return 0;
+#endif
+}
+
+int optimalLevel() {
+#ifdef CPU_FEATURES
+    auto features = cpu_features::GetX86Info().features;
+
+    if (features.avx && features.avx2 && features.avx512f && features.avx512vl && features.avx512bw && features.avx512dq && features.avx512cd)
+        return 3;
+    else if (features.avx && features.avx2)
+        return 2;
+    else
+        return 1;
+#else
+    return 0;
+#endif
+}
+
+bool isMinimalRequirementsMet() {
+#ifdef CPU_FEATURES
+    auto features = cpu_features::GetX86Info().features;
+
+    #if defined(F_X64)
+    return true;
+    #elif defined(F_AVX2)
+    return features.avx && features.avx2;
+    #elif defined(F_AVX512)
+    // we're optimizing for skylake-avx512 features, so we'll check those out
+    return features.avx && features.avx2 && features.avx512f && features.avx512vl && features.avx512bw && features.avx512dq && features.avx512cd;
+    #else
+    return true;
+    #endif
+#else
+    return true;
+#endif
+}
+
+bool isOptimalRequirementsMet() {
+#ifdef CPU_FEATURES
+    return ::binaryLevel() == ::optimalLevel();
+#else
+    return true;
+#endif
+}
+
 BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
 BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
 BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong**, void**, Nd4jLong**, int, int*, Nd4jLong**, Nd4jLong**), LIBND4J_TYPES);
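Taken together, binaryLevel() reports the instruction-set level this binary was compiled for, optimalLevel() reports the best level the host CPU could run, and the two predicates compare them. A minimal standalone sketch of how a caller might consume these exports (hypothetical; in this PR the actual consumer is OpTracker, shown further below):

    #include <cstdio>

    // assumed to match the declarations added to NativeOps.h above
    extern "C" {
        int binaryLevel();               // 1 = generic x86-64, 2 = AVX/AVX2, 3 = AVX512, 0 = unknown
        int optimalLevel();              // best level the current CPU supports
        bool isMinimalRequirementsMet(); // binary can run on this CPU at all
        bool isOptimalRequirementsMet(); // binary fully exploits this CPU
    }

    int main() {
        if (!isMinimalRequirementsMet()) {
            std::printf("binary level %d exceeds this CPU's capabilities\n", binaryLevel());
            return 1;
        }
        if (!isOptimalRequirementsMet())
            std::printf("CPU supports level %d, but this binary targets level %d\n",
                        optimalLevel(), binaryLevel());
        return 0;
    }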

View File

@@ -3576,4 +3576,20 @@ int lastErrorCode() {
 const char* lastErrorMessage() {
     return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage();
 }
+
+int binaryLevel() {
+    return 0;
+}
+
+int optimalLevel() {
+    return 0;
+}
+
+bool isMinimalRequirementsMet() {
+    return true;
+}
+
+bool isOptimalRequirementsMet() {
+    return true;
+}

View File

@@ -53,6 +53,7 @@ CLEAN="false"
 MINIFIER="false"
 TESTS="false"
 VERBOSE="false"
+HELPER=
 NAME=
 while [[ $# > 0 ]]
 do
@@ -60,6 +61,10 @@ key="$1"
 value="${2:-}"
 #Build type (release/debug), packaging type, chip: cpu,cuda, lib type (static/dynamic)
 case $key in
+    -h|--helper)
+    HELPER="$value"
+    shift # past argument
+    ;;
     -o|-platform|--platform)
     OS="$value"
     shift # past argument
@@ -425,7 +430,7 @@ if [ "$PACKAGING" == "msi" ]; then
     PACKAGING_ARG="-DPACKAGING=msi"
 fi
 
-EXPERIMENTAL_ARG="no";
+EXPERIMENTAL_ARG="";
 MINIFIER_ARG="-DLIBND4J_BUILD_MINIFIER=false"
 TESTS_ARG="-DBUILD_TESTS=OFF"
 NAME_ARG="-DLIBND4J_NAME=$NAME"
@@ -461,16 +466,12 @@ if [ "$CHIP" == "cuda" ] && [ -n "$CHIP_VERSION" ]; then
     esac
 fi
 
-[[ -z ${MKLDNN_PATH:-} ]] && MKLDNN_PATH=""
 [[ -z ${OPENBLAS_PATH:-} ]] && OPENBLAS_PATH=""
 
 if [[ -n "${BUILD_PATH:-}" ]]; then
     PREVIFS="$IFS"
     IFS="$BUILD_PATH_SEPARATOR"
     for P in $BUILD_PATH; do
-        if [[ -f "$P/include/mkldnn.h" ]]; then
-            MKLDNN_PATH="$P"
-        fi
         if [[ -f "$P/include/openblas_config.h" ]]; then
             OPENBLAS_PATH="$P"
         fi
@@ -478,18 +479,12 @@ if [[ -n "${BUILD_PATH:-}" ]]; then
     IFS="$PREVIFS"
 fi
 
-if [[ ! -f "$MKLDNN_PATH/include/mkldnn.h" ]]; then
-    echo "Could not find MKL-DNN, please make sure to run the build with Maven or set the MKLDNN_PATH variable"
-    MKLDNN_PATH=""
-fi
 if [[ ! -f "$OPENBLAS_PATH/include/openblas_config.h" ]]; then
     echo "Could not find OpenBLAS, please make sure to run the build with Maven or set the OPENBLAS_PATH variable"
     OPENBLAS_PATH=""
 fi
 
 # replace any backslash with a slash
-MKLDNN_PATH="${MKLDNN_PATH//\\//}"
 OPENBLAS_PATH="${OPENBLAS_PATH//\\//}"
 
 mkbuilddir() {
@@ -501,6 +496,21 @@ mkbuilddir() {
     cd "blasbuild/$CHIP"
 }
 
+if [ "$HELPER" == "" ]; then
+    echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
+    echo "!!                                                                                  !!"
+    echo "!!                                                                                  !!"
+    echo "!!                                                                                  !!"
+    echo "!!                                                                                  !!"
+    echo "!!                                     WARNING!                                     !!"
+    echo "!!                           No helper packages configured!                         !!"
+    echo "!!              You can specify helper by using -h key. I.e. <-h mkldnn>            !!"
+    echo "!!                                                                                  !!"
+    echo "!!                                                                                  !!"
+    echo "!!                                                                                  !!"
+    echo "!!                                                                                  !!"
+    echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
+fi
 
 echo PACKAGING = "${PACKAGING}"
 echo BUILD = "${BUILD}"
@@ -515,11 +525,11 @@ echo OPERATIONS = "${OPERATIONS_ARG}"
 echo MINIFIER = "${MINIFIER_ARG}"
 echo TESTS = "${TESTS_ARG}"
 echo NAME = "${NAME_ARG}"
-echo MKLDNN_PATH = "$MKLDNN_PATH"
 echo OPENBLAS_PATH = "$OPENBLAS_PATH"
+echo HELPERS = "$HELPER"
 mkbuilddir
 pwd
-eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DMKLDNN_PATH="$MKLDNN_PATH" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../..
+eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DHELPERS_"$HELPER"=true "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../..
 if [ "$PARALLEL" == "true" ]; then
     MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ"
 fi
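The net effect on the build flow: helper selection now happens at configure time. The value passed with -h is forwarded to CMake as -DHELPERS_mkldnn=true, which is exactly the variable the new if (${HELPERS_mkldnn}) block in CMakeLists.txt tests before pulling and building MKL-DNN from source. A hypothetical invocation would look like `bash buildnativeoperations.sh -c cpu -h mkldnn` (the -c chip flag is assumed from the rest of the script); omitting -h now only prints the warning banner above instead of failing the build.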

View File

@@ -223,7 +223,7 @@ DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory:
     setCountersToZero();
 
-    if(!lenInBytes == 0) {
+    if(lenInBytes != 0) {
         allocateBuffers(allocBoth);
         writeSpecial();
     }
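The old condition worked only by accident of operator precedence: ! binds tighter than ==, so !lenInBytes == 0 parses as (!lenInBytes) == 0, which happens to have the same truth table as lenInBytes != 0. The rewrite changes readability, not behavior; a standalone compile-time sanity check (hypothetical):

    #include <cstddef>

    // !x == 0 parses as (!x) == 0, which is true exactly when x != 0
    static_assert((!(std::size_t)5 == 0) == ((std::size_t)5 != 0), "nonzero case");
    static_assert((!(std::size_t)0 == 0) == ((std::size_t)0 != 0), "zero case");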

View File

@@ -31,9 +31,10 @@
 #endif
 
 #ifdef HAVE_MKLDNN
+// FIXME: latest mkldnn doesn't ship mklml anymore?
 // include CBLAS from MKL-DNN
-#include <mkl_cblas.h>
-#define CBLAS_H
+//#include <mkl_cblas.h>
+//#define CBLAS_H
 #endif
 
 #ifdef HAVE_OPENBLAS

View File

@@ -29,6 +29,10 @@
 #include <cuda_device_runtime_api.h>
 #endif
 
+// used for MKLDNN etc
+#if !defined(__STANDALONE_BUILD__)
+#include "config.h"
+#endif
 
 #include <dll.h>
 #include <memory>
@@ -49,6 +53,9 @@ class ND4J_EXPORT LaunchContext {
     static std::vector<std::shared_ptr<LaunchContext>> _contexts;
     static std::mutex _mutex;
 
+    // used for MKLDNN
+    void *_engine = nullptr;
+
 #ifdef __CUDABLAS__
 #ifndef __JAVACPP_HACK__
@@ -96,6 +103,8 @@ class ND4J_EXPORT LaunchContext {
         _workspace = theWorkspace;
     }
 
+    void* engine();
+
     int getDeviceID() const {return _deviceID;}
     void setDeviceID(int deviceID) { _deviceID = deviceID; }
     sd::ErrorReference* errorReference();

View File

@@ -29,10 +29,16 @@ nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
 thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
 #endif
 
+#ifdef HAVE_MKLDNN
+#include <mkldnn.hpp>
+#endif
+
 namespace nd4j {
 
     LaunchContext::~LaunchContext() {
+#ifdef HAVE_MKLDNN
+        delete reinterpret_cast<mkldnn::engine*>(_engine);
+#endif
     }
 
     std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
@@ -42,6 +48,10 @@ namespace nd4j {
         // default constructor, just to make clang/ranlib happy
         _workspace = nullptr;
         _deviceID = 0;
+
+#ifdef HAVE_MKLDNN
+        _engine = new mkldnn::engine(mkldnn::engine::kind::cpu, 0);
+#endif
     }
 
     LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
@@ -73,4 +83,8 @@ namespace nd4j {
     sd::ErrorReference* LaunchContext::errorReference() {
         return contextBuffers.errorReference();
     }
+
+    void* LaunchContext::engine() {
+        return _engine;
+    }
 }
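Storing the engine as a void* keeps <mkldnn.hpp> out of the public LaunchContext header; the cost is a cast at each use site. A sketch of the consuming side, assuming helpers follow this pattern (only engine() comes from this diff; the rest is illustrative):

    #ifdef HAVE_MKLDNN
    #include <mkldnn.hpp>

    // hypothetical helper body: recover the typed engine from the opaque pointer
    static void runWithEngine(nd4j::LaunchContext *ctx) {
        auto engine = reinterpret_cast<mkldnn::engine*>(ctx->engine());

        // MKL-DNN 1.x streams are constructed from an engine
        mkldnn::stream stream(*engine);

        // ... build memory descriptors and primitives against *engine,
        //     submit them on the stream, then synchronize ...
        stream.wait();
    }
    #endif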

View File

@@ -169,4 +169,8 @@ LaunchContext::LaunchContext() {
     sd::ErrorReference* LaunchContext::errorReference() {
         return contextBuffers.errorReference();
     }
+
+    void* LaunchContext::engine() {
+        return _engine;
+    }
 }

View File

@@ -28,10 +28,6 @@
 #include <graph/ContextPrototype.h>
 #include <memory/Workspace.h>
 
-#ifdef HAVE_MKLDNN
-#include <MKLDNNStream.h>
-#endif
-
 // CUDA-specific includes
 #ifdef __CUDACC__
@@ -61,11 +57,6 @@ namespace nd4j {
         LaunchContext* _context = nullptr;
         std::vector<nd4j::DataType> _dataTypes;
 
-#ifdef HAVE_MKLDNN
-        std::vector<nd4j::MKLDNNStream> _mkldnnStreams;
-#else
-        std::vector<Nd4jLong> _mkldnnStreams;
-#endif
         std::vector<NDArray*> _fastpath_in;
         std::vector<NDArray*> _fastpath_out;
@@ -122,9 +113,6 @@ namespace nd4j {
         int getBranch();
         void setBranch(int branch);
 
-#ifdef HAVE_MKLDNN
-        std::vector<nd4j::MKLDNNStream>& getMKLDNNStreams() { return _mkldnnStreams; }
-#endif
-
         /**
          *
         * @return

View File

@@ -98,9 +98,6 @@ namespace nd4j {
             this->_inputs.clear();
             this->_fastpath_in.clear();
             this->_fastpath_out.clear();
-#ifdef HAVE_MKLDNN
-            this->_mkldnnStreams.clear();
-#endif
 
             for (auto v:_handles)
                 delete v;

View File

@@ -21,12 +21,11 @@
 #ifndef LIBND4J_MKLDNNSTREAM_H
 #define LIBND4J_MKLDNNSTREAM_H
 
-#ifndef __STANDALONE_BUILD__
+#if !defined(__STANDALONE_BUILD__)
 #include "config.h"
 #endif
 
-#ifdef HAVE_MKLDNN
-#include <mkldnn.hpp>
+#if defined(HAVE_MKLDNN)
 
 namespace nd4j {
     class MKLDNNStream {
@@ -38,26 +37,24 @@ namespace nd4j {
         std::vector<float> _floatArguments;
         std::vector<int> _intArguments;
 
-        mkldnn::engine _engine = mkldnn::engine(mkldnn::engine::cpu, 0);
-        std::vector<mkldnn::memory> _memory;
-        std::vector<mkldnn::primitive> _operations;
-
     public:
         template <typename X, typename Y>
         static bool isSupported() {
+            // FIXME: strict float support doesn't work anymore
             return typeid(X) == typeid(float) && typeid(Y) == typeid(float);
         }
 
         static bool isSupported(const std::vector<const NDArray*> &arrays) {
-            for (auto i = arrays.begin(); i != arrays.end(); i++) {
-                if (*i != nullptr && (*i)->dataType() != nd4j::DataType::FLOAT32) {
+            // FIXME: strict float support doesn't work anymore
+            for (auto v:arrays) {
+                if (v != nullptr && v->dataType() != nd4j::DataType::FLOAT32) {
                     return false;
                 }
             }
             return true;
         }
 
-        MKLDNNStream(const std::string &opName) : _opName(opName) { }
+        explicit MKLDNNStream(const std::string &opName) : _opName(opName) { }
 
         bool checkAndReset(const std::vector<const NDArray*> &inputs, const std::vector<const NDArray*> &outputs,
                            const std::vector<float> &floatArguments, const std::vector<int> &intArguments) {
@@ -66,30 +63,10 @@ namespace nd4j {
                 _outputs = outputs;
                 _floatArguments = floatArguments;
                 _intArguments = intArguments;
-                _operations.clear();
-                _memory.clear();
                 return true;
             }
             return false;
         }
-
-        const mkldnn::engine &getEngine() { return _engine; }
-        void setEngine(const mkldnn::engine &engine) { _engine = engine; }
-
-        const std::vector<mkldnn::memory> &getMemory() { return _memory; }
-        void setMemory(const std::vector<mkldnn::memory> &memory) { _memory = memory; }
-        void addMemory(const mkldnn::memory &memory) { _memory.push_back(memory); }
-
-        const std::vector<mkldnn::primitive> &getOperations() { return _operations; }
-        void setOperations(const std::vector<mkldnn::primitive> &operations) { _operations = operations; }
-        void addOperation(const mkldnn::primitive &operation) { _operations.push_back(operation); }
-
-        bool submitAndWait(mkldnn::stream::kind kind = mkldnn::stream::kind::eager) {
-            nd4j_debug("Executing %s with MKL-DNN\n", _opName.c_str());
-            // need to create a new one because already executed streams become unusable
-            mkldnn::stream stream(kind);
-            return stream.submit(_operations).wait();
-        }
     };
 }
 #endif

View File

@@ -21,6 +21,8 @@
 #include <helpers/OpTracker.h>
 #include <sstream>
 #include <helpers/logger.h>
+#include <NativeOps.h>
 
 using namespace nd4j::ops;
 using namespace nd4j::graph;
@@ -35,6 +37,31 @@ namespace nd4j {
     }
 
     void OpTracker::storeOperation(nd4j::graph::OpType opType, const OpDescriptor& descriptor) {
+        // check out CPU features
+        if (!::isMinimalRequirementsMet()) {
+            auto binaryLevel = ::binaryLevel();
+            auto optimalLevel = ::optimalLevel();
+
+            switch (binaryLevel) {
+                case 3: {
+                        nd4j_printf("libnd4j binary was built with AVX512 support, but current CPU doesn't have this instruction set. Exiting now...","");
+                    }
+                    break;
+                case 2: {
+                        nd4j_printf("libnd4j binary was built with AVX/AVX2 support, but current CPU doesn't have this instruction set. Exiting now...","");
+                    }
+                    break;
+                default: {
+                        nd4j_printf("Unknown binary validation error. Exiting now...","");
+                    }
+                    break;
+            }
+
+            // we're exiting now
+            exit(119);
+        }
+
         if (_map.count(opType) < 1) {
             std::vector<OpDescriptor> vec;
             _map[opType] = vec;
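Placing the check inside storeOperation() means it runs while ops are being registered, i.e. as the library is loaded, so an incompatible binary aborts with exit code 119 before any AVX2/AVX512 code path can be reached. The CMake change above, which compiles OpTracker.cpp with -march=x86-64 -mtune=generic, appears to be what keeps the check itself safe to execute on older CPUs.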

View File

@@ -4417,8 +4417,10 @@ INLINEDEF void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const c
         if(order == shape::order(shapeInfo) || e == 1) {    // e==1 means common vector
             e = 1;
             Nd4jLong len = shape::length(shapeInfo);
-            while(e < len)
-                offsets[e++] = offsets[e - 1] + ews;
+            while(e < len) {
+                offsets[e] = offsets[e - 1] + ews;
+                e++;
+            }
             return;
         }
     }
@@ -4464,8 +4466,10 @@ INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong
         if(shape[j] == 1) { --j; continue; }    // ignore dimensions equal to unity
         if(j == rankMinusOne) {    // last dimension
-            for(int l = 1; l < shape[j]; ++l)
-                offsets[i++] = offsets[i - 1] + strides[j];
+            for(int l = 1; l < shape[j]; ++l) {
+                offsets[i] = offsets[i - 1] + strides[j];
+                i++;
+            }
             --j;
         }
         else if(idx[j] < shape[j] - 1) {
@@ -4489,8 +4493,10 @@ INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong
         if(shape[j] == 1) { ++j; continue; }    // ignore dimensions equal to unity
         if(j == 0) {    // last dimension
-            for(int l = 1; l < shape[j]; ++l)
-                offsets[i++] = offsets[i - 1] + strides[j];
+            for(int l = 1; l < shape[j]; ++l) {
+                offsets[i] = offsets[i - 1] + strides[j];
+                i++;
+            }
             ++j;
         }
         else if(idx[j] < shape[j] - 1) {
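All three hunks fix the same latent bug: in offsets[e++] = offsets[e - 1] + ews; the increment of e on the left and the read of e on the right are unsequenced relative to each other before C++17, so the compiler may evaluate offsets[e - 1] either before or after the increment. Splitting the increment into its own statement pins the intended order:

    // ambiguous pre-C++17: is offsets[e - 1] read before or after e++ takes effect?
    //     offsets[e++] = offsets[e - 1] + ews;

    // fixed: each offset is the previous offset plus the element-wise stride
    offsets[e] = offsets[e - 1] + ews;
    e++;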

View File

@@ -32,7 +32,7 @@ namespace nd4j {
            OpDescriptor * _descriptor;

            bool prepareOutputs(Context& block);
-            virtual Nd4jStatus validateAndExecute(Context& block) = 0;
+            Nd4jStatus validateAndExecute(Context& block) override = 0;
        public:
            BooleanOp(const char *name, int numInputs, bool scalar);
            ~BooleanOp();

View File

@@ -30,7 +30,7 @@ namespace nd4j {
    namespace ops {
        class ND4J_EXPORT BroadcastableOp : public DeclarableCustomOp {
        protected:
-            virtual Nd4jStatus validateAndExecute(Context& block) = 0;
+            Nd4jStatus validateAndExecute(Context& block) override = 0;
        public:
            BroadcastableOp(const char *name, int numTArgs, int numIArgs);
            ~BroadcastableOp();

View File

@@ -30,12 +30,12 @@ namespace nd4j {
            /**
             * This method executes this Op
             */
-            virtual Nd4jStatus validateAndExecute(Context& block) = 0;
+            Nd4jStatus validateAndExecute(Context& block) override = 0;
        public:
            DeclarableCustomOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs);
            ~DeclarableCustomOp();

-            virtual ShapeList* calculateOutputShape(ShapeList* inputShapes, nd4j::graph::Context& block) = 0;
+            ShapeList* calculateOutputShape(ShapeList* inputShapes, nd4j::graph::Context& block) override = 0;
        };
    }
}

View File

@@ -32,7 +32,7 @@ namespace nd4j {
    namespace ops {
        class ND4J_EXPORT DeclarableListOp : public nd4j::ops::DeclarableOp {
        protected:
-            virtual Nd4jStatus validateAndExecute(Context& block) = 0;
+            Nd4jStatus validateAndExecute(Context& block) override = 0;

            nd4j::NDArray* getZ(Context& block, int inputId);
            void setupResult(NDArray* array, Context& block);

View File

@@ -30,12 +30,12 @@ namespace nd4j {
            /**
             * This method executes this Op
             */
-            virtual Nd4jStatus validateAndExecute(Context& block) = 0;
+            Nd4jStatus validateAndExecute(Context& block) override = 0;
        public:
            DeclarableReductionOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs);
            ~DeclarableReductionOp();

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
        };
    }
}

View File

@@ -30,13 +30,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyBroadcastBoolOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyBroadcastBoolOp();
            LegacyBroadcastBoolOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -30,13 +30,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyBroadcastOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyBroadcastOp();
            LegacyBroadcastOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -32,13 +32,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyIndexReduceOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyIndexReduceOp();
            LegacyIndexReduceOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -41,13 +41,13 @@ namespace nd4j {
            int _numInputs = 0;

            // All Op classes provide own specific implementation for this method
-            virtual Nd4jStatus validateAndExecute(Context& block) = 0;
+            Nd4jStatus validateAndExecute(Context& block) override = 0;
        public:
            LegacyOp(int numInputs);
            LegacyOp(int numInputs, int opNum);

            // All Op classes provide own specific implementation for this method
-            virtual ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) = 0;
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override = 0;
            virtual LegacyOp* clone() = 0;
        };
    }

View File

@@ -30,13 +30,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyPairwiseTransformBoolOp: public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyPairwiseTransformBoolOp();
            LegacyPairwiseTransformBoolOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -30,13 +30,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyPairwiseTransformOp: public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyPairwiseTransformOp();
            LegacyPairwiseTransformOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -32,7 +32,7 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyRandomOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyRandomOp();
            LegacyRandomOp(int opNum);
@@ -43,10 +43,10 @@ namespace nd4j {
            nd4j::ResultSet* execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
            nd4j::ResultSet* execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);

-            Nd4jStatus execute(Context* block);
-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            Nd4jStatus execute(Context* block) override;
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -30,13 +30,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyReduce3Op : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyReduce3Op();
            LegacyReduce3Op(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -27,13 +27,13 @@ namespace nd4j {
    namespace ops {
        class ND4J_EXPORT LegacyReduceBoolOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyReduceBoolOp();
            LegacyReduceBoolOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -27,13 +27,13 @@ namespace nd4j {
    namespace ops {
        class ND4J_EXPORT LegacyReduceFloatOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyReduceFloatOp();
            LegacyReduceFloatOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -27,13 +27,13 @@ namespace nd4j {
    namespace ops {
        class ND4J_EXPORT LegacyReduceLongOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyReduceLongOp();
            LegacyReduceLongOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -27,13 +27,13 @@ namespace nd4j {
    namespace ops {
        class ND4J_EXPORT LegacyReduceSameOp: public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyReduceSameOp();
            LegacyReduceSameOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -30,15 +30,15 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyScalarBoolOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyScalarBoolOp();
            LegacyScalarBoolOp(int opNum);
            LegacyScalarBoolOp(int opNum, NDArray &scalar);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -30,15 +30,15 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyScalarOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context& block);
+            Nd4jStatus validateAndExecute(Context& block) override;
        public:
            LegacyScalarOp();
            LegacyScalarOp(int opNum);
            LegacyScalarOp(int opNum, NDArray &scalar);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -30,13 +30,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyStatsOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context &block);
+            Nd4jStatus validateAndExecute(Context &block) override;
        public:
            LegacyStatsOp();
            LegacyStatsOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -31,13 +31,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyTransformAnyOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context &block);
+            Nd4jStatus validateAndExecute(Context &block) override;
        public:
            LegacyTransformAnyOp();
            LegacyTransformAnyOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -32,13 +32,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyTransformBoolOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context &block);
+            Nd4jStatus validateAndExecute(Context &block) override;
        public:
            LegacyTransformBoolOp();
            LegacyTransformBoolOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -31,13 +31,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyTransformFloatOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context &block);
+            Nd4jStatus validateAndExecute(Context &block) override;
        public:
            LegacyTransformFloatOp();
            LegacyTransformFloatOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -32,13 +32,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyTransformSameOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context &block);
+            Nd4jStatus validateAndExecute(Context &block) override;
        public:
            LegacyTransformSameOp();
            LegacyTransformSameOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -32,13 +32,13 @@ namespace nd4j {
        */
        class ND4J_EXPORT LegacyTransformStrictOp : public LegacyOp {
        protected:
-            Nd4jStatus validateAndExecute(Context &block);
+            Nd4jStatus validateAndExecute(Context &block) override;
        public:
            LegacyTransformStrictOp();
            LegacyTransformStrictOp(int opNum);

-            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block);
-            virtual LegacyOp* clone();
+            ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context &block) override;
+            LegacyOp* clone() override;
        };
    }
}

View File

@@ -26,6 +26,7 @@
 #include <map>
 #include <mutex>
 #include <ops/declarable/DeclarableOp.h>
+#include <ops/declarable/PlatformHelper.h>

 // handlers part
 #include <cstdlib>
@@ -59,10 +60,16 @@ namespace nd4j {
        std::map<Nd4jLong, std::string> _msvc;
+       // pointers to our operations
        std::map<Nd4jLong, nd4j::ops::DeclarableOp*> _declarablesLD;
        std::map<std::string, nd4j::ops::DeclarableOp*> _declarablesD;
        std::vector<nd4j::ops::DeclarableOp *> _uniqueD;
+       // pointers to platform-specific helpers
+       std::map<Nd4jLong, nd4j::ops::platforms::PlatformHelper*> _helpersLH;
+       std::map<std::string, nd4j::ops::platforms::PlatformHelper*> _helpersH;
+       std::vector<nd4j::ops::platforms::PlatformHelper*> _uniqueH;
        std::mutex _locker;
        std::string _opsList;
        bool isInit = false;
@@ -82,16 +89,22 @@ namespace nd4j {
        const char * getAllCustomOperations();

        /**
-        * This method registers operation
+        * This method registers an operation in our registry, so we can use it later
         *
         * @param op
         */
        bool registerOperation(const char* name, nd4j::ops::DeclarableOp* op);
        bool registerOperation(nd4j::ops::DeclarableOp *op);
+       void registerHelper(nd4j::ops::platforms::PlatformHelper* op);
+       bool hasHelper(Nd4jLong hash);
        nd4j::ops::DeclarableOp* getOperation(const char *name);
        nd4j::ops::DeclarableOp* getOperation(Nd4jLong hash);
-       nd4j::ops::DeclarableOp* getOperation(std::string& name);
+       nd4j::ops::DeclarableOp* getOperation(std::string &name);
+       nd4j::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash);
        std::vector<Nd4jLong> getAllHashes();
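A rough sketch of how a dispatch path could consult the new helper registry before falling back to the generic kernel; `executeGeneric` is a hypothetical stand-in, and the `getInstance()` singleton accessor is assumed:

// Illustrative dispatch only; not the actual DeclarableOp execution code.
Nd4jStatus executeWithHelpers(Nd4jLong opHash, nd4j::graph::Context &block) {
    auto registrator = nd4j::ops::OpRegistrator::getInstance();
    if (registrator->hasHelper(opHash)) {
        auto helper = registrator->getPlatformHelper(opHash);
        if (helper->isUsable(block))              // helper opts in per invocation
            return helper->invokeHelper(block);   // fast platform-specific path
    }
    return executeGeneric(opHash, block);         // portable fallback kernel
}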

View File

@@ -0,0 +1,81 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef SD_PLATFORMHELPER_H
#define SD_PLATFORMHELPER_H

#include <ShapeUtils.h>
#include <graph/Context.h>
#include <string>
#include <pointercast.h>
#include <dll.h>

namespace nd4j {
    namespace ops {
        namespace platforms {
            /**
             * This abstract class defines the methods used by platform-specific helper implementations
             */
            class ND4J_EXPORT PlatformHelper {
            protected:
                // name of the operation this helper is built for
                std::string _name;

                // hash of the operation this helper is built for
                Nd4jLong _hash;
            public:
                PlatformHelper(const char *name);

                ~PlatformHelper() = default;

                std::string name();

                Nd4jLong hash();

                /**
                 * This method checks whether the given helper can be used with the given input/output/configuration options
                 *
                 * @param context
                 * @return
                 */
                virtual bool isUsable(graph::Context &context) = 0;

                /**
                 * This method invokes the helper. Typically it replaces the actual op execution
                 *
                 * @param context
                 * @return
                 */
                virtual Nd4jStatus invokeHelper(graph::Context &context) = 0;

                /**
                 * Helper method, needed for compatibility with DeclarableOp macros
                 *
                 * @param ctx
                 * @param inputId
                 * @return
                 */
                nd4j::NDArray *getZ(graph::Context &ctx, int inputId);
            };
        }
    }
}

#endif //SD_PLATFORMHELPER_H
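For illustration only, a subclass of this interface might look like the sketch below. The class is hypothetical; the registration macros real helpers go through are out of scope here, and the capability check is a placeholder:

#include <ops/declarable/PlatformHelper.h>

namespace nd4j { namespace ops { namespace platforms {

// Hypothetical helper for the 'conv2d' op.
class Conv2DExampleHelper : public PlatformHelper {
public:
    Conv2DExampleHelper() : PlatformHelper("conv2d") {}

    bool isUsable(graph::Context &context) override {
        // Inspect inputs/attributes and return false whenever the vendor
        // kernel cannot handle them (dtype, rank, layout, ...); placeholder here.
        return true;
    }

    Nd4jStatus invokeHelper(graph::Context &context) override {
        // Call into the vendor library here, writing results into getZ(context, 0).
        return Status::OK();
    }
};

}}}

The two-step contract matters: the registry can hold a helper for an op unconditionally, while isUsable() lets the helper decline individual invocations it cannot accelerate, so the generic kernel remains the safety net.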

View File

@@ -28,54 +28,6 @@
namespace nd4j {
namespace ops {
#ifdef HAVE_MKLDNN
using namespace mkldnn;
static void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo();
Nd4jLong rank = shape[0];
Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
Nd4jLong dim2 = axis >= 2 ? 1 : 2;
Nd4jLong dim3 = axis >= 3 ? 2 : 3;
mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = mkldnn::memory::data_type::f32;
auto format = mkldnn::memory::format::nchw;
auto supposed_to_be_any_format = mkldnn::memory::format::nChw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
*batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
user_src_md->data.format = mkldnn_blocked; // overrides format
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[0];
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[dim1];
user_src_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? src->stridesOf()[dim2] : 1;
user_src_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? src->stridesOf()[dim3] : 1;
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
*batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
user_diff_src_md->data.format = mkldnn_blocked; // overrides format
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[0];
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
}
if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
*batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
user_dst_md->data.format = mkldnn_blocked; // overrides format
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[0];
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[dim1];
user_dst_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
user_dst_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
}
}
#endif
CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
    auto input = INPUT_VARIABLE(0);
@@ -208,84 +160,6 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
    for(int i = 1; i < block.width(); ++i)
        REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_NEW op: types of all input arrays should be the same !");
#ifdef HAVE_MKLDNN
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) && numOfAxes == 1) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("batchnorm_new"));
}
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
weights({0, 1, 0, 0}).assign(1.0f);
weights({1, 2, 0, 0}).assign(0.0f);
if (streams[0].checkAndReset({input, mean, variance, gamma, beta}, {output}, {(float)epsilon}, axes)) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
&batchnorm_src_md, nullptr, &batchnorm_dst_md,
&user_src_md, nullptr, &user_dst_md, axes[0]);
auto batchnorm_desc = batch_normalization_forward::desc(prop_kind::forward_inference, batchnorm_src_md, epsilon,
use_global_stats | (applyScale || applyOffset ? use_scale_shift : 0));
auto engine = streams[0].getEngine();
auto batchnorm_prim_desc = batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
auto user_src_memory = mkldnn::memory({user_src_md, engine}, input->buffer());
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output->buffer());
auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_primitive_desc(), mean->buffer());
auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_primitive_desc(), variance->buffer());
auto batchnorm_src_memory = user_src_memory;
streams[0].addMemory(user_src_memory);
if (mkldnn::memory::primitive_desc({batchnorm_src_md, engine})
!= user_src_memory.get_primitive_desc()) {
batchnorm_src_memory = mkldnn::memory({batchnorm_src_md, engine});
streams[0].addMemory(batchnorm_src_memory);
streams[0].addOperation(reorder(user_src_memory, batchnorm_src_memory));
}
auto batchnorm_dst_memory = user_dst_memory;
streams[0].addMemory(user_dst_memory);
if (mkldnn::memory::primitive_desc(batchnorm_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_primitive_desc());
streams[0].addMemory(batchnorm_dst_memory);
}
streams[0].addMemory(batchnorm_mean_memory);
streams[0].addMemory(batchnorm_variance_memory);
if (applyScale || applyOffset) {
auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_primitive_desc(), weights.buffer());
streams[0].addMemory(batchnorm_weights_memory);
streams[0].addOperation(batch_normalization_forward(batchnorm_prim_desc, (mkldnn::primitive::at)batchnorm_src_memory,
(mkldnn::primitive::at)batchnorm_mean_memory, (mkldnn::primitive::at)batchnorm_variance_memory, (mkldnn::primitive::at)batchnorm_weights_memory, batchnorm_dst_memory));
} else {
streams[0].addOperation(batch_normalization_forward(batchnorm_prim_desc, (mkldnn::primitive::at)batchnorm_src_memory,
(mkldnn::primitive::at)batchnorm_mean_memory, (mkldnn::primitive::at)batchnorm_variance_memory, batchnorm_dst_memory));
}
if (mkldnn::memory::primitive_desc(batchnorm_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(batchnorm_dst_memory, user_dst_memory));
}
}
if (applyScale || applyOffset) {
if (gamma != nullptr) {
weights({0, 1, 0, 0}).assign(gamma);
}
if (beta != nullptr) {
weights({1, 2, 0, 0}).assign(beta);
}
}
streams[0].submitAndWait();
return Status::OK();
}
#endif
nd4j_debug("MKL-DNN is not used for batchnorm_new!\n", 0); nd4j_debug("MKL-DNN is not used for batchnorm_new!\n", 0);
// formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
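A tiny self-contained numeric check of that formula (plain C++, independent of NDArray):

#include <cmath>
#include <cstdio>

int main() {
    // output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
    float input = 3.0f, mean = 1.0f, variance = 4.0f;
    float gamma = 2.0f, beta = 0.5f, epsilon = 1e-5f;
    float normalized = (input - mean) / std::sqrt(variance + epsilon);
    printf("%f\n", gamma * normalized + beta);   // ~2.5: 2 * (2 / 2) + 0.5
    return 0;
}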

View File

@ -29,12 +29,7 @@
namespace nd4j {
namespace ops {
#ifdef HAVE_MKLDNN
using namespace mkldnn;
#endif
CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
    auto input   = INPUT_VARIABLE(0);                                   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                   // [kD, kH, kW, iC, oC] always
    auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;     // [oC]
@@ -70,83 +65,6 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
    if(isSameMode)                       // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
#ifdef HAVE_MKLDNN
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output})) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("conv3dnew"));
}
if (streams[0].checkAndReset({input, weights, bias}, {output}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNCDHW})) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
ConvolutionUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNCDHW,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights, nullptr, bias, output,
&conv_src_md, nullptr, &conv_weights_md, nullptr, &conv_bias_md, &conv_dst_md,
&user_src_md, nullptr, &user_weights_md, nullptr, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
: convolution_forward::desc(prop_kind::forward,
convolution_direct, conv_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
auto engine = streams[0].getEngine();
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
auto user_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output->buffer());
auto conv_src_memory = user_src_memory;
streams[0].addMemory(user_src_memory);
if (mkldnn::memory::primitive_desc(conv_prim_desc.src_primitive_desc())
!= user_src_memory.get_primitive_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_primitive_desc());
streams[0].addMemory(conv_src_memory);
streams[0].addOperation(reorder(user_src_memory, conv_src_memory));
}
auto conv_weights_memory = user_weights_memory;
streams[0].addMemory(user_weights_memory);
if (mkldnn::memory::primitive_desc(conv_prim_desc.weights_primitive_desc())
!= user_weights_memory.get_primitive_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_primitive_desc());
streams[0].addMemory(conv_weights_memory);
streams[0].addOperation(reorder(user_weights_memory, conv_weights_memory));
}
auto conv_dst_memory = user_dst_memory;
streams[0].addMemory(user_dst_memory);
if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_primitive_desc());
streams[0].addMemory(conv_dst_memory);
}
if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_primitive_desc(), bias->buffer());
streams[0].addMemory(conv_bias_memory);
streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_bias_memory, conv_dst_memory));
} else {
streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_dst_memory));
}
if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(conv_dst_memory, user_dst_memory));
}
}
streams[0].submitAndWait();
return Status::OK();
}
#endif
nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0); nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0);
std::vector<int> permutForOutput; std::vector<int> permutForOutput;
@ -297,151 +215,6 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
if(isSameMode) // SAME if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
#ifdef HAVE_MKLDNN
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB})) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("conv3dnew_bp_weights"));
streams.push_back(MKLDNNStream("conv3dnew_bp_data"));
}
bool resetW = streams[0].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNDHWC});
bool resetI = streams[1].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNDHWC});
if (resetW || resetI) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
ConvolutionUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNDHWC,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights, gradW, gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
: convolution_forward::desc(prop_kind::forward,
convolution_direct, conv_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, streams[0].getEngine());
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
convolution_direct, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
: convolution_backward_weights::desc(
convolution_direct, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
auto engine = streams[0].getEngine();
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
auto userW_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
auto userW_weights_memory = mkldnn::memory({user_diff_weights_md, engine}, gradW->buffer());
auto userW_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
streams[0].addMemory(userW_src_memory);
if (mkldnn::memory::primitive_desc(convW_prim_desc.src_primitive_desc())
!= userW_src_memory.get_primitive_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_primitive_desc());
streams[0].addMemory(convW_src_memory);
streams[0].addOperation(reorder(userW_src_memory, convW_src_memory));
}
auto convW_weights_memory = userW_weights_memory;
streams[0].addMemory(userW_weights_memory);
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
!= userW_weights_memory.get_primitive_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_primitive_desc());
streams[0].addMemory(convW_weights_memory);
}
auto convW_dst_memory = userW_dst_memory;
streams[0].addMemory(userW_dst_memory);
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_dst_primitive_desc())
!= userW_dst_memory.get_primitive_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_primitive_desc());
streams[0].addMemory(convW_dst_memory);
streams[0].addOperation(reorder(userW_dst_memory, convW_dst_memory));
}
if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_primitive_desc(), gradB->buffer());
streams[0].addMemory(convW_bias_memory);
streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory, convW_bias_memory));
} else {
streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory));
}
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
!= userW_weights_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(convW_weights_memory, userW_weights_memory));
}
}
if (gradI != nullptr) {
auto convI_desc =
convolution_backward_data::desc(
convolution_direct, conv_diff_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
auto engine = streams[1].getEngine();
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
auto userI_src_memory = mkldnn::memory({user_diff_src_md, engine}, gradI->buffer());
auto userI_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
streams[1].addMemory(userI_src_memory);
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
!= userI_src_memory.get_primitive_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_primitive_desc());
streams[1].addMemory(convI_src_memory);
}
auto convI_weights_memory = userI_weights_memory;
streams[1].addMemory(userI_weights_memory);
if (mkldnn::memory::primitive_desc(convI_prim_desc.weights_primitive_desc())
!= userI_weights_memory.get_primitive_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_primitive_desc());
streams[1].addMemory(convI_weights_memory);
streams[1].addOperation(reorder(userI_weights_memory, convI_weights_memory));
}
auto convI_dst_memory = userI_dst_memory;
streams[1].addMemory(userI_dst_memory);
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_dst_primitive_desc())
!= userI_dst_memory.get_primitive_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_primitive_desc());
streams[1].addMemory(convI_dst_memory);
streams[1].addOperation(reorder(userI_dst_memory, convI_dst_memory));
}
streams[1].addOperation(convolution_backward_data(convI_prim_desc, convI_dst_memory, convI_weights_memory, convI_src_memory));
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
!= userI_src_memory.get_primitive_desc()) {
streams[1].addOperation(reorder(convI_src_memory, userI_src_memory));
}
}
}
if (gradW != nullptr) {
streams[0].submitAndWait();
}
if (gradI != nullptr) {
streams[1].submitAndWait();
}
return Status::OK();
}
#endif
nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0); nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0);
std::vector<int> gradOaxesForDot; std::vector<int> gradOaxesForDot;
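For reference, the arithmetic that SAME-mode padding has to satisfy, shown for one spatial dimension. This is a generic formulation of the standard rule, not necessarily calcPadding3D's exact code:

#include <algorithm>
#include <cstdio>

// SAME-mode arithmetic for one spatial dimension: choose total padding so
// that o = ceil(i / s) outputs exist for input extent i, kernel k, stride s,
// dilation d. The start side gets total / 2; the remainder lands at the end.
static int samePaddingStart(int i, int o, int k, int s, int d) {
    int effK = (k - 1) * d + 1;                       // effective (dilated) kernel extent
    int total = std::max(0, (o - 1) * s + effK - i);  // padding needed overall
    return total / 2;
}

int main() {
    // iD=7, sD=2 -> oD=4; kD=3, dD=1 -> total padding 2, start side gets 1
    printf("pD = %d\n", samePaddingStart(7, 4, 3, 2, 1));
    return 0;
}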

View File

@@ -41,7 +41,6 @@ namespace nd4j {
    REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead", input->rankOf());

-   // FIXME: double?
    double alpha = T_ARG(1);
    double beta  = T_ARG(2);
    double bias  = T_ARG(0);
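For context, these three scalars feed the standard local response normalization, output[c] = input[c] / (bias + alpha * sum over the window of input[j]^2)^beta. A minimal sketch over a 1-D channel vector; the window handling here is an assumption, since implementations differ in how they treat the depth radius:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Minimal LRN across channels; 'depth' is the half-window radius.
static std::vector<float> lrn(const std::vector<float> &in, int depth,
                              float bias, float alpha, float beta) {
    std::vector<float> out(in.size());
    for (int c = 0; c < (int) in.size(); ++c) {
        float sum = 0.0f;
        for (int j = std::max(0, c - depth); j <= std::min((int) in.size() - 1, c + depth); ++j)
            sum += in[j] * in[j];
        out[c] = in[c] / std::pow(bias + alpha * sum, beta);
    }
    return out;
}

int main() {
    auto out = lrn({1.0f, 2.0f, 3.0f, 4.0f}, /*depth=*/1, /*bias=*/1.0f, /*alpha=*/1e-4f, /*beta=*/0.75f);
    for (float v : out) printf("%f ", v);
    printf("\n");
}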

View File

@ -24,9 +24,6 @@
#include <NDArray.h>
#include <graph/Context.h>
-#ifdef HAVE_MKLDNN
-#include <helpers/MKLDNNStream.h>
-#endif
#include <execution/LaunchContext.h>

namespace nd4j {
@@ -197,44 +194,6 @@ namespace nd4j {
        }
#ifdef HAVE_MKLDNN
static void getMKLDNNMemoryDescConv2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
static void getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
static void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
static void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
#endif
        static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);

        // static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
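A hedged usage sketch of this entry point, with shapes following the layout comments in the ops above. The NDArrayFactory overloads and the Context(nodeId) constructor used here are assumptions; real callers pass the op's own block:

// NCHW shapes (isNCHW = 1); SAME-style padding of 1 keeps 32x32 spatial dims.
auto input   = NDArrayFactory::create<float>('c', {1, 3, 32, 32});   // [bS, iC, iH, iW]
auto weights = NDArrayFactory::create<float>('c', {3, 3, 3, 16});    // [kH, kW, iC, oC]
auto output  = NDArrayFactory::create<float>('c', {1, 16, 32, 32});  // [bS, oC, oH, oW]

nd4j::graph::Context ctx(1);   // minimal context; assumed constructor
ConvolutionUtils::conv2d(ctx, &input, &weights, /*bias=*/nullptr, &output,
                         /*kH=*/3, /*kW=*/3, /*sH=*/1, /*sW=*/1,
                         /*pH=*/1, /*pW=*/1, /*dH=*/1, /*dW=*/1,
                         /*isSameMode=*/1, /*isNCHW=*/1);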

View File

@ -28,121 +28,6 @@
namespace nd4j {
namespace ops {
#ifdef HAVE_MKLDNN
using namespace mkldnn;
void ConvolutionUtils::getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW };
mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW };
pool_strides = { sH, sW };
pool_kernel = { kH, kW };
pool_padding = { pH, pW };
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
algorithm = poolingMode == 0 ? pooling_max
: extraParam0 == 0 ? pooling_avg_exclude_padding
: pooling_avg_include_padding;
auto type = mkldnn::memory::data_type::f32;
auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
auto supposed_to_be_any_format = mkldnn::memory::format::nChw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1];
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2];
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1];
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2];
}
}
void ConvolutionUtils::getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
pool_strides = { sD, sH, sW };
pool_kernel = { kD, kH, kW };
pool_padding = { pD, pH, pW };
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
(oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
algorithm = poolingMode == 0 ? pooling_max
: extraParam0 == 0 ? pooling_avg_exclude_padding
: pooling_avg_include_padding;
auto type = mkldnn::memory::data_type::f32;
auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc;
auto supposed_to_be_any_format = mkldnn::memory::format::nCdhw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1];
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2];
user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3];
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1];
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2];
user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3];
}
}
#endif
//////////////////////////////////////////////////////////////////////////
// [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
@@ -348,174 +233,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescPool3d(
        }
#ifdef HAVE_MKLDNN
using namespace mkldnn;
void ConvolutionUtils::getMKLDNNMemoryDescConv2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };
conv_strides = { sH, sW };
conv_padding = { pH, pW };
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
auto type = mkldnn::memory::data_type::f32;
auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
auto formatw = mkldnn::memory::format::hwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1];
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2];
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format = mkldnn_blocked; // overrides "formatw = hwio"
user_weights_md->data.layout_desc.blocking.strides[0][0] = weights->stridesOf()[3];
user_weights_md->data.layout_desc.blocking.strides[0][1] = weights->stridesOf()[2];
user_weights_md->data.layout_desc.blocking.strides[0][2] = weights->stridesOf()[0];
user_weights_md->data.layout_desc.blocking.strides[0][3] = weights->stridesOf()[1];
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format = mkldnn_blocked; // overrides "formatw = hwio"
user_diff_weights_md->data.layout_desc.blocking.strides[0][0] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.layout_desc.blocking.strides[0][1] = diff_weights->stridesOf()[2];
user_diff_weights_md->data.layout_desc.blocking.strides[0][2] = diff_weights->stridesOf()[0];
user_diff_weights_md->data.layout_desc.blocking.strides[0][3] = diff_weights->stridesOf()[1];
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::any);
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format::any);
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1];
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2];
}
}
void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
conv_strides = { sD, sH, sW };
conv_padding = { pD, pH, pW };
conv_padding_r = { (oD - 1) * sD - iD + kD - pD,
(oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
auto type = mkldnn::memory::data_type::f32;
auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc;
auto formatw = mkldnn::memory::format::dhwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1];
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2];
user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3];
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format::any);
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format = mkldnn_blocked; // overrides "formatw = dhwio"
user_weights_md->data.layout_desc.blocking.strides[0][0] = weights->stridesOf()[4];
user_weights_md->data.layout_desc.blocking.strides[0][1] = weights->stridesOf()[3];
user_weights_md->data.layout_desc.blocking.strides[0][2] = weights->stridesOf()[0];
user_weights_md->data.layout_desc.blocking.strides[0][3] = weights->stridesOf()[1];
user_weights_md->data.layout_desc.blocking.strides[0][4] = weights->stridesOf()[2];
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format::any);
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format = mkldnn_blocked; // overrides "formatw = dhwio"
user_diff_weights_md->data.layout_desc.blocking.strides[0][0] = diff_weights->stridesOf()[4];
user_diff_weights_md->data.layout_desc.blocking.strides[0][1] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.layout_desc.blocking.strides[0][2] = diff_weights->stridesOf()[0];
user_diff_weights_md->data.layout_desc.blocking.strides[0][3] = diff_weights->stridesOf()[1];
user_diff_weights_md->data.layout_desc.blocking.strides[0][4] = diff_weights->stridesOf()[2];
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::any);
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format::any);
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1];
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2];
user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3];
}
}
#endif
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
@@ -543,83 +260,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
    if(isSameMode)                       // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
#ifdef HAVE_MKLDNN
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<X, Y>()) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("conv2d"));
}
if (streams[0].checkAndReset({input, weights, bias}, {output}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW})) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
ConvolutionUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bias, output,
&conv_src_md, nullptr, &conv_weights_md, nullptr, &conv_bias_md, &conv_dst_md,
&user_src_md, nullptr, &user_weights_md, nullptr, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
: convolution_forward::desc(prop_kind::forward,
convolution_direct, conv_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
auto engine = streams[0].getEngine();
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
auto user_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output->buffer());
auto conv_src_memory = user_src_memory;
streams[0].addMemory(user_src_memory);
if (mkldnn::memory::primitive_desc(conv_prim_desc.src_primitive_desc())
!= user_src_memory.get_primitive_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_primitive_desc());
streams[0].addMemory(conv_src_memory);
streams[0].addOperation(reorder(user_src_memory, conv_src_memory));
}
auto conv_weights_memory = user_weights_memory;
streams[0].addMemory(user_weights_memory);
if (mkldnn::memory::primitive_desc(conv_prim_desc.weights_primitive_desc())
!= user_weights_memory.get_primitive_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_primitive_desc());
streams[0].addMemory(conv_weights_memory);
streams[0].addOperation(reorder(user_weights_memory, conv_weights_memory));
}
auto conv_dst_memory = user_dst_memory;
streams[0].addMemory(user_dst_memory);
if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_primitive_desc());
streams[0].addMemory(conv_dst_memory);
}
if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_primitive_desc(), const_cast<NDArray*>(bias)->buffer());
streams[0].addMemory(conv_bias_memory);
streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_bias_memory, conv_dst_memory));
} else {
streams[0].addOperation(convolution_forward(conv_prim_desc, conv_src_memory, conv_weights_memory, conv_dst_memory));
}
if (mkldnn::memory::primitive_desc(conv_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(conv_dst_memory, user_dst_memory));
}
}
streams[0].submitAndWait();
return;
}
#endif
nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); nd4j_debug("MKL-DNN is not used for conv2d!\n", 0);
std::vector<int> permutForOutput; std::vector<int> permutForOutput;
@ -686,151 +326,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
if(isSameMode) // SAME if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
#ifdef HAVE_MKLDNN
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<X, Y>()) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("conv2d_bp_weights"));
streams.push_back(MKLDNNStream("conv2d_bp_data"));
}
bool resetW = streams[0].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW});
bool resetI = streams[1].checkAndReset({input, weights, bias, gradO}, {gradI, gradW, gradB}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW});
if (resetW || resetI) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
ConvolutionUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
: convolution_forward::desc(prop_kind::forward,
convolution_direct, conv_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, streams[0].getEngine());
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
convolution_direct, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero)
: convolution_backward_weights::desc(
convolution_direct, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
auto engine = streams[0].getEngine();
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
auto userW_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray*>(input)->buffer());
auto userW_weights_memory = mkldnn::memory({user_diff_weights_md, engine}, gradW->buffer());
auto userW_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
streams[0].addMemory(userW_src_memory);
if (mkldnn::memory::primitive_desc(convW_prim_desc.src_primitive_desc())
!= userW_src_memory.get_primitive_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_primitive_desc());
streams[0].addMemory(convW_src_memory);
streams[0].addOperation(reorder(userW_src_memory, convW_src_memory));
}
auto convW_weights_memory = userW_weights_memory;
streams[0].addMemory(userW_weights_memory);
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
!= userW_weights_memory.get_primitive_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_primitive_desc());
streams[0].addMemory(convW_weights_memory);
}
auto convW_dst_memory = userW_dst_memory;
streams[0].addMemory(userW_dst_memory);
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_dst_primitive_desc())
!= userW_dst_memory.get_primitive_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_primitive_desc());
streams[0].addMemory(convW_dst_memory);
streams[0].addOperation(reorder(userW_dst_memory, convW_dst_memory));
}
if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_primitive_desc(), gradB->buffer());
streams[0].addMemory(convW_bias_memory);
streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory, convW_bias_memory));
} else {
streams[0].addOperation(convolution_backward_weights(convW_prim_desc, convW_src_memory, convW_dst_memory, convW_weights_memory));
}
if (mkldnn::memory::primitive_desc(convW_prim_desc.diff_weights_primitive_desc())
!= userW_weights_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(convW_weights_memory, userW_weights_memory));
}
}
if (gradI != nullptr) {
auto convI_desc =
convolution_backward_data::desc(
convolution_direct, conv_diff_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r, padding_kind::zero);
auto engine = streams[1].getEngine();
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
auto userI_src_memory = mkldnn::memory({user_diff_src_md, engine}, gradI->buffer());
auto userI_weights_memory = mkldnn::memory({user_weights_md, engine}, const_cast<NDArray*>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray*>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
streams[1].addMemory(userI_src_memory);
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
!= userI_src_memory.get_primitive_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_primitive_desc());
streams[1].addMemory(convI_src_memory);
}
auto convI_weights_memory = userI_weights_memory;
streams[1].addMemory(userI_weights_memory);
if (mkldnn::memory::primitive_desc(convI_prim_desc.weights_primitive_desc())
!= userI_weights_memory.get_primitive_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_primitive_desc());
streams[1].addMemory(convI_weights_memory);
streams[1].addOperation(reorder(userI_weights_memory, convI_weights_memory));
}
auto convI_dst_memory = userI_dst_memory;
streams[1].addMemory(userI_dst_memory);
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_dst_primitive_desc())
!= userI_dst_memory.get_primitive_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_primitive_desc());
streams[1].addMemory(convI_dst_memory);
streams[1].addOperation(reorder(userI_dst_memory, convI_dst_memory));
}
streams[1].addOperation(convolution_backward_data(convI_prim_desc, convI_dst_memory, convI_weights_memory, convI_src_memory));
if (mkldnn::memory::primitive_desc(convI_prim_desc.diff_src_primitive_desc())
!= userI_src_memory.get_primitive_desc()) {
streams[1].addOperation(reorder(convI_src_memory, userI_src_memory));
}
}
}
if (gradW != nullptr) {
streams[0].submitAndWait();
}
if (gradI != nullptr) {
streams[1].submitAndWait();
}
return;
}
#endif
nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0);
std::vector<int> gradOaxesForDot; std::vector<int> gradOaxesForDot;
@@ -1268,62 +763,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
const int oH = output.sizeAt(2);
const int oW = output.sizeAt(3);
#ifdef HAVE_MKLDNN
if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("pooling2d"));
}
if (streams[0].checkAndReset({&input}, {&output}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0})) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
ConvolutionUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
bS, iC, iH, iW, oC, oH, oW, &input, nullptr, &output, algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
auto engine = streams[0].getEngine();
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output.buffer());
auto pool_src_memory = user_src_memory;
streams[0].addMemory(user_src_memory);
if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
!= user_src_memory.get_primitive_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
streams[0].addMemory(pool_src_memory);
streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
}
auto pool_dst_memory = user_dst_memory;
streams[0].addMemory(user_dst_memory);
if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
streams[0].addMemory(pool_dst_memory);
}
streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory));
if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(pool_dst_memory, user_dst_memory));
}
}
streams[0].submitAndWait();
return;
}
#endif
nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0); nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0);
const Nd4jLong iStride0 = input.stridesOf()[0]; const Nd4jLong iStride0 = input.stridesOf()[0];
@@ -1504,62 +943,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
const int oH = output.sizeAt(3);
const int oW = output.sizeAt(4);
#ifdef HAVE_MKLDNN
if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("pooling3d"));
}
if (streams[0].checkAndReset({&input}, {&output}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0})) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
ConvolutionUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, &input, nullptr, &output, algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
auto engine = streams[0].getEngine();
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output.buffer());
auto pool_src_memory = user_src_memory;
streams[0].addMemory(user_src_memory);
if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
!= user_src_memory.get_primitive_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
streams[0].addMemory(pool_src_memory);
streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
}
auto pool_dst_memory = user_dst_memory;
streams[0].addMemory(user_dst_memory);
if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
streams[0].addMemory(pool_dst_memory);
}
streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory));
if (mkldnn::memory::primitive_desc(pool_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(pool_dst_memory, user_dst_memory));
}
}
streams[0].submitAndWait();
return;
}
#endif
nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0); nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0);
const Nd4jLong iStride0 = input.stridesOf()[0]; const Nd4jLong iStride0 = input.stridesOf()[0];
@@ -1776,91 +1159,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
const int oH = gradO.sizeAt(2);
const int oW = gradO.sizeAt(3);
#ifdef HAVE_MKLDNN
if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("pooling2d_bp"));
}
if (streams[0].checkAndReset({&input, &gradO}, {&gradI}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0})) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
ConvolutionUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
bS, iC, iH, iW, oC, oH, oW, &input, &gradI, &gradO, algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
// input is sometimes null, so we can't rely on pool_src_md being valid
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
const_cast<NDArray&>(input).buffer() != nullptr ? pool_src_md : pool_diff_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
auto engine = streams[0].getEngine();
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory({user_src_md, engine}, gradI.buffer());
auto userB_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray&>(gradO).buffer());
auto poolB_src_memory = userB_src_memory;
streams[0].addMemory(userB_src_memory);
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
!= userB_src_memory.get_primitive_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_primitive_desc());
streams[0].addMemory(poolB_src_memory);
}
auto poolB_dst_memory = userB_dst_memory;
streams[0].addMemory(userB_dst_memory);
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_dst_primitive_desc())
!= userB_dst_memory.get_primitive_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_primitive_desc());
streams[0].addMemory(poolB_dst_memory);
streams[0].addOperation(reorder(userB_dst_memory, poolB_dst_memory));
}
if (algorithm == mkldnn::pooling_max) {
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
auto pool_src_memory = user_src_memory;
streams[0].addMemory(user_src_memory);
if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
!= user_src_memory.get_primitive_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
streams[0].addMemory(pool_src_memory);
streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
}
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
streams[0].addMemory(pool_dst_memory);
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_primitive_desc());
streams[0].addMemory(pool_workspace_memory);
streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory, pool_workspace_memory));
streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, pool_workspace_memory, poolB_src_memory));
} else {
streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, poolB_src_memory));
}
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
!= userB_src_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(poolB_src_memory, userB_src_memory));
}
}
streams[0].submitAndWait();
return;
}
#endif
nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0); nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0);
const Nd4jLong iStride0 = input.stridesOf()[0]; const Nd4jLong iStride0 = input.stridesOf()[0];
@@ -2099,94 +1397,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
const int oH = gradO.sizeAt(3);
const int oW = gradO.sizeAt(4);
#ifdef HAVE_MKLDNN
if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported<T, T>()) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("pooling3d_bp"));
}
if (streams[0].checkAndReset({&input, &gradO}, {&gradI}, {}, {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0})) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
ConvolutionUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, &input, &gradI, &gradO, algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
// input is sometimes null, so we can't rely on pool_src_md being valid
if (const_cast<NDArray&>(input).buffer() == nullptr) {
pool_src_md = pool_diff_src_md;
user_src_md = user_diff_src_md;
}
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
auto engine = streams[0].getEngine();
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r, padding_kind::zero);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory({user_diff_src_md, engine}, gradI.buffer());
auto userB_dst_memory = mkldnn::memory({user_dst_md, engine}, const_cast<NDArray&>(gradO).buffer());
auto poolB_src_memory = userB_src_memory;
streams[0].addMemory(userB_src_memory);
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
!= userB_src_memory.get_primitive_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_primitive_desc());
streams[0].addMemory(poolB_src_memory);
}
auto poolB_dst_memory = userB_dst_memory;
streams[0].addMemory(userB_dst_memory);
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_dst_primitive_desc())
!= userB_dst_memory.get_primitive_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_primitive_desc());
streams[0].addMemory(poolB_dst_memory);
streams[0].addOperation(reorder(userB_dst_memory, poolB_dst_memory));
}
if (algorithm == mkldnn::pooling_max) {
auto user_src_memory = mkldnn::memory({user_src_md, engine}, const_cast<NDArray&>(input).buffer());
auto pool_src_memory = user_src_memory;
streams[0].addMemory(user_src_memory);
if (mkldnn::memory::primitive_desc(pool_prim_desc.src_primitive_desc())
!= user_src_memory.get_primitive_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_primitive_desc());
streams[0].addMemory(pool_src_memory);
streams[0].addOperation(reorder(user_src_memory, pool_src_memory));
}
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_primitive_desc());
streams[0].addMemory(pool_dst_memory);
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_primitive_desc());
streams[0].addMemory(pool_workspace_memory);
streams[0].addOperation(pooling_forward(pool_prim_desc, pool_src_memory, pool_dst_memory, pool_workspace_memory));
streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, pool_workspace_memory, poolB_src_memory));
} else {
streams[0].addOperation(pooling_backward(poolB_prim_desc, poolB_dst_memory, poolB_src_memory));
}
if (mkldnn::memory::primitive_desc(poolB_prim_desc.diff_src_primitive_desc())
!= userB_src_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(poolB_src_memory, userB_src_memory));
}
}
streams[0].submitAndWait();
return;
}
#endif
nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0); nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0);
const Nd4jLong iStride0 = input.stridesOf()[0]; const Nd4jLong iStride0 = input.stridesOf()[0];


@@ -27,107 +27,9 @@ namespace nd4j {
namespace ops {
namespace helpers {
#ifdef HAVE_MKLDNN
using namespace mkldnn;
static void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo();
long rank = shape[0];
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
long dim2 = axis >= 2 ? 1 : 2;
long dim3 = axis >= 3 ? 2 : 3;
mkldnn::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = mkldnn::memory::data_type::f32;
auto format = axis == 1 ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
auto supposed_to_be_any_format = format; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
*lrn_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
user_src_md->data.format = mkldnn_blocked;
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[0];
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[dim1];
user_src_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? src->stridesOf()[dim2] : 1;
user_src_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? src->stridesOf()[dim3] : 1;
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
*lrn_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
user_diff_src_md->data.format = mkldnn_blocked;
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[0];
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
}
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
*lrn_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
user_dst_md->data.format = mkldnn_blocked;
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[0];
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[dim1];
user_dst_md->data.layout_desc.blocking.strides[0][2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
user_dst_md->data.layout_desc.blocking.strides[0][3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
}
}
#endif
template <typename T>
static int lrnFunctor_(nd4j::graph::Context& block, NDArray* input, NDArray* output, int depth, float bias, float alpha, float beta) {
#ifdef HAVE_MKLDNN
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output})) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
if (streams.empty()) {
streams.push_back(MKLDNNStream("lrn"));
}
if (streams[0].checkAndReset({input}, {output}, {(float)bias, (float)alpha, (float)beta}, {depth})) {
mkldnn_memory_desc_t empty;
mkldnn::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);
getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md, &user_src_md, nullptr, &user_dst_md, input->rankOf() - 1);
auto lrn_desc = lrn_forward::desc(prop_kind::forward_inference, lrn_across_channels, lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias);
auto engine = streams[0].getEngine();
auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine);
auto user_src_memory = mkldnn::memory({user_src_md, engine}, input->buffer());
auto user_dst_memory = mkldnn::memory({user_dst_md, engine}, output->buffer());
auto lrn_src_memory = user_src_memory;
streams[0].addMemory(user_src_memory);
if (mkldnn::memory::primitive_desc(lrn_prim_desc.src_primitive_desc())
!= user_src_memory.get_primitive_desc()) {
lrn_src_memory = mkldnn::memory(lrn_prim_desc.src_primitive_desc());
streams[0].addMemory(lrn_src_memory);
streams[0].addOperation(reorder(user_src_memory, lrn_src_memory));
}
auto lrn_dst_memory = user_dst_memory;
streams[0].addMemory(user_dst_memory);
if (mkldnn::memory::primitive_desc(lrn_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
lrn_dst_memory = mkldnn::memory(lrn_prim_desc.dst_primitive_desc());
streams[0].addMemory(lrn_dst_memory);
}
streams[0].addOperation(lrn_forward(lrn_prim_desc, lrn_src_memory, lrn_dst_memory));
if (mkldnn::memory::primitive_desc(lrn_prim_desc.dst_primitive_desc())
!= user_dst_memory.get_primitive_desc()) {
streams[0].addOperation(reorder(lrn_dst_memory, user_dst_memory));
}
}
streams[0].submitAndWait();
return ND4J_STATUS_OK;
}
#endif
nd4j_debug("MKL-DNN is not used for lrn!\n", 0); nd4j_debug("MKL-DNN is not used for lrn!\n", 0);
const int rank = input->rankOf(); const int rank = input->rankOf();


@@ -24,6 +24,7 @@
#include <NDArrayFactory.h>
#include <exceptions/graph_exception.h>
#include <exceptions/unresolved_input_exception.h>
#include <ops/declarable/OpRegistrator.h>
namespace nd4j {
namespace ops {
@@ -501,7 +502,22 @@ namespace nd4j {
prepTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeStart - timeEnter).count();
}
Nd4jStatus status = this->validateAndExecute(*block);
Nd4jStatus status;
bool hasHelper = false;
// if we have platform-specific helper for this op - invoke it
if (OpRegistrator::getInstance()->hasHelper(this->getOpHash())) {
auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash());
if (helper->isUsable(*block)) {
status = helper->invokeHelper(*block);
hasHelper = true;
}
}
// if we don't have platform-specific helper - invoke generic implementation
if (!hasHelper)
status = this->validateAndExecute(*block);
// optionally saving execution time
if (Environment::getInstance()->isProfiling()) {


@@ -113,8 +113,13 @@ namespace nd4j {
for (auto x : _uniqueD)
delete x;
for (auto x: _uniqueH)
delete x;
_uniqueD.clear();
_uniqueH.clear();
_declarablesD.clear();
_declarablesLD.clear();
@@ -144,6 +149,8 @@ namespace nd4j {
return _opsList.c_str();
}
bool OpRegistrator::registerOperation(const char* name, nd4j::ops::DeclarableOp* op) {
std::string str(name);
std::pair<std::string, nd4j::ops::DeclarableOp*> pair(str, op);
@@ -165,6 +172,19 @@ namespace nd4j {
return registerOperation(op->getOpName()->c_str(), op);
}
void OpRegistrator::registerHelper(nd4j::ops::platforms::PlatformHelper* op) {
if (_helpersLH.count(op->hash()) > 0)
throw std::runtime_error("Tried to double register PlatformHelper");
_uniqueH.emplace_back(op);
std::pair<std::string, nd4j::ops::platforms::PlatformHelper*> pair(op->name(), op);
_helpersH.insert(pair);
std::pair<Nd4jLong, nd4j::ops::platforms::PlatformHelper*> pair2(op->hash(), op);
_helpersLH.insert(pair2);
}
nd4j::ops::DeclarableOp* OpRegistrator::getOperation(const char *name) {
std::string str(name);
return getOperation(str);
@@ -207,6 +227,16 @@ namespace nd4j {
return _declarablesD.at(name);
}
nd4j::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper(Nd4jLong hash) {
if (_helpersLH.count(hash) == 0)
throw std::runtime_error("Requested helper can't be found");
return _helpersLH[hash];
}
bool OpRegistrator::hasHelper(Nd4jLong hash) {
return _helpersLH.count(hash) > 0;
}
int OpRegistrator::numberOfOperations() {
return (int) _declarablesLD.size();


@@ -0,0 +1,86 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../PlatformHelper.h"
#include <graph/Variable.h>
namespace nd4j {
namespace ops {
namespace platforms {
PlatformHelper::PlatformHelper(const char *name) {
// we just store name/hash of target operation
_name = std::string(name);
_hash = HashHelper::getInstance()->getLongHash(_name);
}
nd4j::NDArray *PlatformHelper::getZ(graph::Context &ctx, int inputId) {
NDArray *z = nullptr;
if (ctx.isFastPath()) {
if (ctx.fastpath_out().size() <= inputId) {
if (ctx.isInplace()) {
z = ctx.fastpath_in()[inputId];
} else
throw std::runtime_error("fastpath_out: unresolved output array");
} else {
z = ctx.fastpath_out()[inputId];
}
} else {
std::pair<int, int> pair(ctx.nodeId(), inputId);
if (ctx.isInplace()) {
z = ctx.variable(inputId)->getNDArray();
// it's possible (though unlikely) that the variable doesn't exist yet - create it on demand
if (!ctx.getVariableSpace()->hasVariable(pair)) {
auto var = new graph::Variable();
ctx.getVariableSpace()->putVariable(pair, var);
}
// now we're saving input array as output array
auto var = ctx.getVariableSpace()->getVariable(pair);
var->markRemovable(false);
var->setNDArray(z);
} else if (!ctx.isInplace()) {
auto var = ctx.variable(pair);
if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) {
z = var->getNDArray();
} else {
nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId());
}
} else {
nd4j_printf("BOOM!\n", "");
throw std::runtime_error("Boom!");
}
}
return z;
}
std::string PlatformHelper::name() {
return _name;
}
Nd4jLong PlatformHelper::hash() {
return _hash;
}
}
}
}


@@ -0,0 +1 @@
This folder contains platform-specific optimized implementations for custom operations
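For illustration, a helper couples an optimized implementation with an applicability check. A minimal sketch, assuming the PLATFORM_IMPL/PLATFORM_CHECK macros from platform_boilerplate.h ("some_op" is a hypothetical op name):

// sketch only: PLATFORM_IMPL defines the optimized implementation,
// PLATFORM_CHECK decides whether it applies to the given inputs
PLATFORM_IMPL(some_op) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);
    // platform-specific (e.g. MKL-DNN) computation goes here
    return Status::OK();
}

PLATFORM_CHECK(some_op) {
    // the helper runs only when this returns true; otherwise the generic implementation is used
    return block.isUseMKLDNN();
}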


@@ -0,0 +1,143 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(avgpool2d) {
auto input = INPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
input->rankOf());
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
auto argI = *(block.getIArguments());
auto output = OUTPUT_VARIABLE(0);
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
int pH = INT_ARG(4);
int pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto isSameMode = static_cast<bool>(INT_ARG(8));
const auto extraParam0 = INT_ARG(9);
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
dH, dW);
int oH = 0;
int oW = 0;
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
if (!isNCHW) {
input = new NDArray(
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = new NDArray(
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
if (isSameMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
const int bS = input->sizeAt(0);
const int iC = input->sizeAt(1);
const int oC = output->sizeAt(1);
auto poolingMode = PoolingType::AVG_POOL;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
mkldnn::stream stream(engine);
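// if the pooling primitive prefers a different memory layout than the user buffer, reorder the input into it first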
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
//streams[0].submitAndWait();
if (!isNCHW) {
delete input;
delete output;
}
return Status::OK();
}
PLATFORM_CHECK(avgpool2d) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}


@@ -0,0 +1,153 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(avgpool2d_bp) {
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(
1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int extraParam0 = INT_ARG(9);
int isNCHW =
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedGradOShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCHW) {
input = new NDArray(input->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = new NDArray(gradI->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = new NDArray(gradO->permute(
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
auto poolingMode = PoolingType::AVG_POOL;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
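// input may be empty in the backward pass, so fall back to the diff_src descriptor when building the forward primitive hint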
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding,
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
mkldnn::stream stream(engine);
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(avgpool2d_bp) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}


@@ -0,0 +1,145 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(avgpool3dnew) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto output = OUTPUT_VARIABLE(
0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
int kD = INT_ARG(0); // filter(kernel) depth
int kH = INT_ARG(1); // filter(kernel) height
int kW = INT_ARG(2); // filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
REQUIRE_TRUE(input->rankOf() == 5, 0,
"AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
"AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
"AVGPOOL3DNEW op: wrong shape of output array, expected is %s, but got %s instead !",
expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
if (!isNCDHW) {
input = new NDArray(
input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
output = new NDArray(
output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
auto poolingMode = PoolingType::AVG_POOL;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
if (!isNCDHW) {
delete input;
delete output;
}
return Status::OK();
}
PLATFORM_CHECK(avgpool3dnew) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}


@@ -0,0 +1,158 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(avgpool3dnew_bp) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(
1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
const int kD = INT_ARG(0); // filter(kernel) depth
const int kH = INT_ARG(1); // filter(kernel) height
const int kW = INT_ARG(2); // filter(kernel) width
const int sD = INT_ARG(3); // strides depth
const int sH = INT_ARG(4); // strides height
const int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
const int dD = INT_ARG(9); // dilations depth
const int dH = INT_ARG(10); // dilations height
const int dW = INT_ARG(11); // dilations width
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
REQUIRE_TRUE(input->rankOf() == 5, 0,
"AVGPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
"AVGPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCDHW) {
input = new NDArray(input->permute(
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = new NDArray(gradI->permute(
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradO = new NDArray(gradO->permute(
{0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
auto poolingMode = PoolingType::AVG_POOL;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
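// input is sometimes null, so we can't rely on pool_src_md being valid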
if (input->buffer() == nullptr) {
pool_src_md = pool_diff_src_md;
user_src_md = user_diff_src_md;
}
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCDHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(avgpool3dnew_bp) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}


@@ -0,0 +1,166 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
#include <NDArrayFactory.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(batchnorm_new) {
auto input = INPUT_VARIABLE(0);
auto mean = INPUT_VARIABLE(1);
auto variance = INPUT_VARIABLE(2);
NDArray *gamma = nullptr;
NDArray *beta = nullptr;
auto output = OUTPUT_VARIABLE(0);
const bool applyScale = (bool) INT_ARG(0);
const bool applyOffset = (bool) INT_ARG(1);
const double epsilon = T_ARG(0);
if (applyScale)
gamma = INPUT_VARIABLE(3);
if (applyOffset)
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
std::vector<int> axes;
if (block.numI() > 2)
for (int i = 2; i < block.numI(); ++i)
axes.push_back(INT_ARG(i));
else
axes.push_back(input->rankOf() - 1);
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
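// 2 x C weights array: row 0 holds gamma (scale), row 1 holds beta (shift); defaults below are identity scale and zero shift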
weights({0, 1, 0, 0}).assign(1.0f);
weights({1, 2, 0, 0}).assign(0.0f);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(
empty), user_dst_md(empty);
auto norm_flag = normalization_flags::use_global_stats;
if (applyScale || applyOffset)
norm_flag |= normalization_flags::use_scale_shift;
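// use_global_stats: mean/variance are taken from the inputs; use_scale_shift: gamma/beta come packed in the weights array created above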
mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
&batchnorm_src_md, nullptr, &batchnorm_dst_md,
&user_src_md, nullptr, &user_dst_md, axes[0]);
auto batchnorm_desc = batch_normalization_forward::desc(prop_kind::forward_inference, batchnorm_src_md, epsilon, norm_flag);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto batchnorm_prim_desc = batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine,
mean->buffer());
auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine,
variance->buffer());
auto batchnorm_src_memory = user_src_memory;
mkldnn::memory m(batchnorm_src_md, engine);
if (m.get_desc() != user_src_memory.get_desc()) {
batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine);
reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
batchnorm_src_memory);
}
auto batchnorm_dst_memory = user_dst_memory;
if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine);
}
if (applyScale || applyOffset) {
if (gamma != nullptr) {
weights({0, 1, 0, 0}).assign(gamma);
}
if (beta != nullptr) {
weights({1, 2, 0, 0}).assign(beta);
}
auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
batch_normalization_forward(batchnorm_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, batchnorm_src_memory},
{MKLDNN_ARG_MEAN, batchnorm_mean_memory},
{MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
{MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory},
{MKLDNN_ARG_DST, batchnorm_dst_memory}});
} else {
batch_normalization_forward(batchnorm_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, batchnorm_src_memory},
{MKLDNN_ARG_MEAN, batchnorm_mean_memory},
{MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
{MKLDNN_ARG_DST, batchnorm_dst_memory}});
}
if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
user_dst_memory);
}
stream.wait();
return Status::OK();
}
PLATFORM_CHECK(batchnorm_new) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto mean = INPUT_VARIABLE(1);
auto variance = INPUT_VARIABLE(2);
NDArray *gamma = nullptr;
NDArray *beta = nullptr;
auto output = OUTPUT_VARIABLE(0);
const bool applyScale = (bool) INT_ARG(0);
const bool applyOffset = (bool) INT_ARG(1);
const double epsilon = T_ARG(0);
if (applyScale)
gamma = INPUT_VARIABLE(3);
if (applyOffset)
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
std::vector<int> axes;
if (block.numI() > 2)
for (int i = 2; i < block.numI(); ++i)
axes.push_back(INT_ARG(i));
else
axes.push_back(input->rankOf() - 1);
return block.isUseMKLDNN() &&
nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) &&
axes.size() == 1;
}
}
}
}


@@ -0,0 +1,153 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
const int isNCHW) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
bias, output,
&conv_src_md, nullptr, &conv_weights_md, nullptr,
&conv_bias_md, &conv_dst_md,
&user_src_md, nullptr, &user_weights_md, nullptr,
&user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
}
auto conv_weights_memory = user_weights_memory;
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
conv_weights_memory);
}
auto conv_dst_memory = user_dst_memory;
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
}
if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine,
const_cast<NDArray *>(bias)->buffer());
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_BIAS, conv_bias_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
} else {
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
}
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
}
stream.wait();
}
PLATFORM_IMPL(conv2d) {
auto input = INPUT_VARIABLE(0);   // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW);
return Status::OK();
}
PLATFORM_CHECK(conv2d) {
// we don't want to use MKL-DNN if the CPU doesn't support AVX/AVX2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
// conv2d is only available for float32 dtype
return block.isUseMKLDNN() && input->dataType() == nd4j::DataType::FLOAT32 &&
weights->dataType() == nd4j::DataType::FLOAT32;
}
}
}
}
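For SAME mode the implementation defers to ConvolutionUtils::calcPadding2D before building any descriptors. As a rough sketch, the standard TF-style arithmetic that computation is assumed to perform looks like the following, folding dilation into an "effective" kernel extent; calcSamePadding2D is a hypothetical name, not a function from this PR:

// Assumed SAME-padding arithmetic: pad so oH/oW strided windows cover iH/iW inputs.
static void calcSamePadding2D(int &pH, int &pW, int oH, int oW, int iH, int iW,
                              int kH, int kW, int sH, int sW, int dH, int dW) {
    const int ekH = (kH - 1) * dH + 1;   // effective kernel height under dilation
    const int ekW = (kW - 1) * dW + 1;   // effective kernel width under dilation
    pH = ((oH - 1) * sH + ekH - iH) / 2; // top padding; odd remainder lands bottom/right
    pW = ((oW - 1) * sW + ekW - iW) / 2; // (real implementations clamp negatives to zero)
}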

@@ -0,0 +1,243 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(conv2d_bp) {
auto input = INPUT_VARIABLE(0);   // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_VARIABLE(1);  // [kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",
weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",
gradO->rankOf());
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md,
&user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
LaunchContext::defaultContext()->engine()));
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r)
: convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc);
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory);
}
auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
}
auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory);
}
if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
} else {
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
userW_weights_memory);
}
stream.wait();
}
if (gradI != nullptr) {
auto convI_desc =
convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md,
conv_weights_md, conv_dst_md, conv_strides,
conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc);
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
}
auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory);
}
auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory);
}
convolution_backward_data(convI_prim_desc).execute(stream,
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
userI_src_memory);
}
stream.wait();
}
return Status::OK();
}
PLATFORM_CHECK(conv2d_bp) {
// we don't want to use MKL-DNN if the CPU doesn't support AVX/AVX2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);   // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_VARIABLE(1);  // [kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
return block.isUseMKLDNN() &&
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
}
}
}
}
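Both backward passes above hang off a forward descriptor: MKL-DNN 1.x requires a forward primitive_desc as a hint when constructing backward primitive_descs, which is why conv2d_bp builds a convolution_forward::desc it never actually executes. Condensed from the code above, the skeleton is:

// Skeleton of the hint pattern used by conv2d_bp (descriptor setup elided).
auto fwd_pd  = convolution_forward::primitive_desc(conv_desc, engine);
auto bwdW_pd = convolution_backward_weights::primitive_desc(convW_desc, engine, fwd_pd);
auto bwdD_pd = convolution_backward_data::primitive_desc(convI_desc, engine, fwd_pd);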

@@ -0,0 +1,167 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(conv3dnew) {
auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
REQUIRE_TRUE(input->rankOf() == 5, 0,
"CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0,
"CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !",
weights->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
"CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !",
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
"CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
oC, bias->rankOf(), bias->lengthOf());
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
isNCDHW,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
nullptr, bias, output,
&conv_src_md, nullptr, &conv_weights_md, nullptr,
&conv_bias_md, &conv_dst_md,
&user_src_md, nullptr, &user_weights_md, nullptr,
&user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
}
auto conv_weights_memory = user_weights_memory;
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
conv_weights_memory);
}
auto conv_dst_memory = user_dst_memory;
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
}
if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_BIAS, conv_bias_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
} else {
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
}
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
}
stream.wait();
return Status::OK();
}
PLATFORM_CHECK(conv3dnew) {
// we don't want to use MKL-DNN if the CPU doesn't support AVX/AVX2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
}
}
}
}

@@ -0,0 +1,263 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(conv3dnew_bp) {
auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0);  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !",
weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !",
gradO->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH,
dW, iD, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
"CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !",
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
"CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
oC, bias->rankOf(), bias->lengthOf());
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
isNCDHW,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
gradW, gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md,
&user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
LaunchContext::defaultContext()->engine()));
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r)
: convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc);
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory);
}
auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
}
auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory);
}
if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
} else {
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
userW_weights_memory);
}
stream.wait();
}
if (gradI != nullptr) {
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto,
conv_diff_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc);
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
}
auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory);
}
auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory);
}
convolution_backward_data(convI_prim_desc).execute(stream,
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
userI_src_memory);
}
stream.wait();
}
return Status::OK();
}
PLATFORM_CHECK(conv3dnew_bp) {
// we don't want to use MKL-DNN if the CPU doesn't support AVX/AVX2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0);  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
return block.isUseMKLDNN() &&
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
}
}
}
}
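conv3dnew_bp validates gradO against a "true" output size computed by ConvolutionUtils::calcOutSizePool3D. A sketch of the conventional formulas that computation is assumed to follow, with dilation folded into an effective kernel extent (one axis shown; depth, height, and width are treated identically):

// Assumed output-size arithmetic behind calcOutSizePool3D.
static int calcOutSize(int i, int k, int s, int d, bool sameMode) {
    const int ek = (k - 1) * d + 1;     // effective kernel extent under dilation
    return sameMode ? (i + s - 1) / s   // SAME: ceil(i / s)
                    : (i - ek) / s + 1; // VALID: floor((i - ek) / s) + 1
}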

@@ -0,0 +1,97 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(lrn) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead",
input->rankOf());
double alpha = T_ARG(1);
double beta = T_ARG(2);
double bias = T_ARG(0);
int depth = INT_ARG(0);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);
mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md,
&user_src_md, nullptr, &user_dst_md, input->rankOf() - 1);
auto lrn_desc = lrn_forward::desc(prop_kind::forward_inference, algorithm::lrn_across_channels,
lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto lrn_src_memory = user_src_memory;
if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) {
lrn_src_memory = mkldnn::memory(lrn_prim_desc.src_desc(), engine);
reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory);
}
auto lrn_dst_memory = user_dst_memory;
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
lrn_dst_memory = mkldnn::memory(lrn_prim_desc.dst_desc(), engine);
}
lrn_forward(lrn_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, lrn_src_memory},
{MKLDNN_ARG_DST, lrn_dst_memory}});
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory);
}
stream.wait();
return Status::OK();
}
PLATFORM_CHECK(lrn) {
// we don't want to use MKL-DNN if the CPU doesn't support AVX/AVX2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}
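The only non-obvious part of the LRN mapping is the window translation: nd4j carries a half-window depth and a per-element alpha, while MKL-DNN's lrn_forward::desc takes the full window size and an alpha already summed across that window. Spelled out from the inline expressions in the desc call above:

// Window translation done inline in the lrn_forward::desc call.
const int local_size = 2 * depth + 1;           // e.g. depth = 2 -> 5-wide window
const double mkldnn_alpha = alpha * local_size; // rescale per-element alpha to window alpha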

@@ -0,0 +1,149 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(maxpool2d) {
auto input = INPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
input->rankOf());
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
auto argI = *(block.getIArguments());
auto output = OUTPUT_VARIABLE(0);
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
int pH = INT_ARG(4);
int pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto isSameMode = static_cast<bool>(INT_ARG(8));
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
dH, dW);
int oH = 0;
int oW = 0;
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
if (!isNCHW) {
input = new NDArray(input->permute({0, 3, 1, 2}));   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
if (isSameMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
const int bS = input->sizeAt(0);
const int iC = input->sizeAt(1);
const int oC = output->sizeAt(1);
auto poolingMode = PoolingType::MAX_POOL;
int extraParam0 = 1;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
mkldnn::stream stream(engine);
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
if (!isNCHW) {
delete input;
delete output;
}
return Status::OK();
}
PLATFORM_CHECK(maxpool2d) {
// we don't want to use MKL-DNN if the CPU doesn't support AVX/AVX2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}

@@ -0,0 +1,178 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(maxpool2d_bp) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int extraParam0 = INT_ARG(9);
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"MAXPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
"MAXPOOL2D_BP op: dilation must not be zero, but got {%i, %i} instead", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedGradOShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"MAXPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCHW) {
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
auto poolingMode = PoolingType::MAX_POOL;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
// input is sometimes null, so we can't rely on pool_src_md being valid
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding,
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine);
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}});
// probably wrong, fix that
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(maxpool2d_bp) {
// we don't want to use MKL-DNN if the CPU doesn't support AVX/AVX2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}
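maxpool2d_bp re-runs the forward pooling before the backward pass because MKL-DNN's max-pool backward consumes a workspace (the recorded argmax positions) that only a training-mode forward pass produces; this is why the forward desc here uses prop_kind::forward rather than forward_inference. Reduced to its skeleton, using the memory objects set up above:

// Workspace pattern from maxpool2d_bp: forward records argmax, backward replays it.
auto ws = mkldnn::memory(pool_prim_desc.workspace_desc(), engine);
pooling_forward(pool_prim_desc).execute(stream,
        {{MKLDNN_ARG_SRC, pool_src_memory},
         {MKLDNN_ARG_DST, pool_dst_memory},
         {MKLDNN_ARG_WORKSPACE, ws}});
pooling_backward(poolB_prim_desc).execute(stream,
        {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
         {MKLDNN_ARG_WORKSPACE, ws},
         {MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});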

@@ -0,0 +1,155 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(maxpool3dnew) {
auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
int kD = INT_ARG(0); // filter(kernel) depth
int kH = INT_ARG(1); // filter(kernel) height
int kW = INT_ARG(2); // filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
REQUIRE_TRUE(input->rankOf() == 5, 0,
"MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
"MAXPOOL3DNEW op: dilation must not be zero, but got {%i, %i, %i} instead", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
"MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
if (!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
dW);
auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding,
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
if (!isNCDHW) {
delete input;
delete output;
}
return Status::OK();
}
PLATFORM_CHECK(maxpool3dnew) {
// we don't want to use MKL-DNN if the CPU doesn't support AVX/AVX2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}

@@ -0,0 +1,185 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(maxpool3dnew_bp) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
const int kD = INT_ARG(0); // filter(kernel) depth
const int kH = INT_ARG(1); // filter(kernel) height
const int kW = INT_ARG(2); // filter(kernel) width
const int sD = INT_ARG(3); // strides depth
const int sH = INT_ARG(4); // strides height
const int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
const int dD = INT_ARG(9); // dilations depth
const int dH = INT_ARG(10); // dilations height
const int dW = INT_ARG(11); // dilations width
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // INT_ARG(14): 1-NDHWC, 0-NCDHW
REQUIRE_TRUE(input->rankOf() == 5, 0,
"MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
"MAXPOOL3D_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3}));   // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
dW);
auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
// input is sometimes null, so we can't rely on pool_src_md being valid
if (input->buffer() == nullptr) {
pool_src_md = pool_diff_src_md;
user_src_md = user_diff_src_md;
}
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine);
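// note: the forward pass is executed again below solely to (re)build the workspace
// of max-element indices, which pooling_backward then uses to route gradients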
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}});
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCDHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(maxpool3dnew_bp) {
// we don't want to use mkldnn if the CPU doesn't support AVX/AVX2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}

View File

@@ -0,0 +1,404 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
//
#include <mkldnn_types.h>
#include "mkldnnUtils.h"
using namespace mkldnn;
namespace nd4j {
namespace mkldnnUtils {
void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW };
mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW };
pool_strides = { sH, sW };
pool_kernel = { kH, kW };
pool_padding = { pH, pW };
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding;
auto type = mkldnn::memory::data_type::f32;
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
}
};
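For reference, the pool_padding_r expression above derives the implicit right/bottom padding from the output size (in SAME mode it can exceed pH/pW); a quick worked example with assumed values:

    // assume iH = 5, kH = 3, sH = 2, pH = 1  =>  oH = (iH + 2*pH - kH) / sH + 1 = 3
    int padRight = (3 - 1) * 2 - 5 + 3 - 1; // = 1
    // the last window starts at padded row 4 and spans rows [4, 6], so exactly
    // one padded row is needed past the bottom edge of the input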
void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
pool_strides = { sD, sH, sW };
pool_kernel = { kD, kH, kW };
pool_padding = { pD, pH, pW };
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
(oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding;
auto type = mkldnn::memory::data_type::f32;
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nCdhw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
}
};
void getMKLDNNMemoryDescConv2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };
conv_strides = { sH, sW };
conv_padding = { pH, pW };
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
auto type = mkldnn::memory::data_type::f32;
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
auto formatw = mkldnn::memory::format_tag::hwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any);
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any);
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
}
}
void getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
conv_strides = { sD, sH, sW };
conv_padding = { pD, pH, pW };
conv_padding_r = { (oD - 1) * sD - iD + kD - pD,
(oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
auto type = mkldnn::memory::data_type::f32;
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
auto formatw = mkldnn::memory::format_tag::dhwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2];
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2];
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any);
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any);
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
}
};
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo();
Nd4jLong rank = shape[0];
Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
Nd4jLong dim2 = axis >= 2 ? 1 : 2;
Nd4jLong dim3 = axis >= 3 ? 2 : 3;
mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = mkldnn::memory::data_type::f32;
auto format = mkldnn::memory::format_tag::nchw;
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
*batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides format
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
*batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
}
if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
*batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
}
};
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo();
Nd4jLong rank = shape[0];
Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
Nd4jLong dim2 = axis >= 2 ? 1 : 2;
Nd4jLong dim3 = axis >= 3 ? 2 : 3;
mkldnn::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = mkldnn::memory::data_type::f32;
auto format = axis == 1 ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
auto supposed_to_be_any_format = format; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
*lrn_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked;
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
*lrn_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked;
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
}
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
*lrn_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked;
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
}
}
mkldnn::engine& getEngine(void *ptr) {
auto eng = reinterpret_cast<mkldnn::engine*>(ptr);
return *eng;
}
}
}

View File

@@ -0,0 +1,124 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
//
#ifndef DEV_TESTS_MKLDNNUTILS_H
#define DEV_TESTS_MKLDNNUTILS_H
#include <NativeOps.h>
#include <NDArray.h>
#include <mkldnn.hpp>
#include <MKLDNNStream.h>
#include <graph/Context.h>
#include <ops/declarable/PlatformHelper.h>
#include <platform_boilerplate.h>
namespace nd4j{
namespace ops {
namespace platforms {
/**
* Here we actually declare our platform helpers
*/
DECLARE_PLATFORM(conv2d);
DECLARE_PLATFORM(conv2d_bp);
DECLARE_PLATFORM(avgpool2d);
DECLARE_PLATFORM(avgpool2d_bp);
DECLARE_PLATFORM(maxpool2d);
DECLARE_PLATFORM(maxpool2d_bp);
DECLARE_PLATFORM(conv3dnew);
DECLARE_PLATFORM(conv3dnew_bp);
DECLARE_PLATFORM(maxpool3dnew);
DECLARE_PLATFORM(maxpool3dnew_bp);
DECLARE_PLATFORM(avgpool3dnew);
DECLARE_PLATFORM(avgpool3dnew_bp);
DECLARE_PLATFORM(lrn);
DECLARE_PLATFORM(batchnorm_new);
}
}
namespace mkldnnUtils {
/**
* Utility methods for MKLDNN
*/
void getMKLDNNMemoryDescConv2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
void getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis);
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis);
mkldnn::engine& getEngine(void *ptr);
}
}
#endif //DEV_TESTS_MKLDNNUTILS_H

View File

@@ -0,0 +1,45 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SD_PLATFORM_BOILERPLATE_H
#define SD_PLATFORM_BOILERPLATE_H
#define DECLARE_PLATFORM(NAME) class ND4J_EXPORT PLATFORM_##NAME : public PlatformHelper {\
public: \
PLATFORM_##NAME() : PlatformHelper(#NAME) { } \
bool isUsable(graph::Context &context) override; \
Nd4jStatus invokeHelper(graph::Context &context) override; \
};
#define PLATFORM_IMPL(NAME) struct ND4J_EXPORT __registratorPlatformHelper_##NAME { \
__registratorPlatformHelper_##NAME() { \
auto helper = new PLATFORM_##NAME(); \
OpRegistrator::getInstance()->registerHelper(helper); \
} \
}; \
static __registratorPlatformHelper_##NAME platformHelper_##NAME; \
Nd4jStatus PLATFORM_##NAME::invokeHelper(nd4j::graph::Context &block)
#define PLATFORM_CHECK(NAME) bool PLATFORM_##NAME::isUsable(graph::Context &block)
#endif //SD_PLATFORM_BOILERPLATE_H
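Taken together, the macros are intended to be used like this (a sketch for a hypothetical op named foo, not generated preprocessor output):

    DECLARE_PLATFORM(foo)     // declares class PLATFORM_foo : public PlatformHelper

    PLATFORM_IMPL(foo) {      // registers an instance, then opens invokeHelper(block)
        auto input = INPUT_VARIABLE(0);
        // ... platform-specific implementation goes here ...
        return Status::OK();
    }

    PLATFORM_CHECK(foo) {     // opens isUsable(block)
        return block.isUseMKLDNN();
    }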

View File

@@ -21,7 +21,7 @@
 #include <iosfwd>
 #include <iostream>
 #include <pointercast.h>
-#if defined(__INTEL_COMPILER) || defined(__F16C__)
+#if defined(__INTEL_COMPILER) || defined(SD_F16C)
 #include <immintrin.h>
 #endif
@@ -122,7 +122,7 @@ static local_def unsigned short hneg(unsigned short h) {
 }
-#if defined(__INTEL_COMPILER) || defined(__F16C__)
+#if defined(__INTEL_COMPILER) || defined(SD_F16C)
 //_Pragma("omp declare simd") inline
 local_def float cpu_ihalf2float(ihalf h) {
     return _cvtsh_ss(h.getX());
@@ -157,7 +157,7 @@ local_def float cpu_ihalf2float(ihalf h) {
 }
 #endif
-#if defined(__INTEL_COMPILER) || defined(__F16C__)
+#if defined(__INTEL_COMPILER) || defined(SD_F16C)
 //_Pragma("omp declare simd") inline
 local_def ihalf cpu_float2ihalf_rn(float f) {
     ihalf ret;

View File

@@ -74,6 +74,7 @@
 <libnd4j.compute></libnd4j.compute>
 <libnd4j.classifier>${libnd4j.platform}</libnd4j.classifier>
 <libnd4j.buildthreads></libnd4j.buildthreads>
+<libnd4j.helper></libnd4j.helper>
 </properties>
 <build>
@@ -175,6 +176,8 @@
 <argument>${libnd4j.tests}</argument>
 <argument>-j</argument>
 <argument>${libnd4j.buildthreads}</argument>
+<argument>-h</argument>
+<argument>${libnd4j.helper}</argument>
 </buildCommand>
 <workingDirectory>${project.basedir}</workingDirectory>
 </configuration>
@@ -391,5 +394,30 @@
 </plugins>
 </build>
 </profile>
+<!-- Profiles to set the default libnd4j.helper property, example: mkldnn -->
+<profile>
+  <id>libnd4j-helper-avx2</id>
+  <activation>
+    <property>
+      <name>libnd4j.extension</name>
+      <value>avx2</value>
+    </property>
+  </activation>
+  <properties>
+    <libnd4j.helper>mkldnn</libnd4j.helper>
+  </properties>
+</profile>
+<profile>
+  <id>libnd4j-helper-avx512</id>
+  <activation>
+    <property>
+      <name>libnd4j.extension</name>
+      <value>avx512</value>
+    </property>
+  </activation>
+  <properties>
+    <libnd4j.helper>mkldnn</libnd4j.helper>
+  </properties>
+</profile>
 </profiles>
 </project>

View File

@@ -131,7 +131,7 @@ endforeach(TMP_PATH)
 if (CPU_BLAS)
     add_executable(runtests ${TEST_SOURCES})
-    target_link_libraries(runtests ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} gtest gtest_main)
+    target_link_libraries(runtests ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main)
 elseif(CUDA_BLAS)
     CUDA_ADD_EXECUTABLE(runtests ${TEST_SOURCES})
     target_link_libraries(runtests ${LIBND4J_NAME} ${CUDA_LIBRARIES} gtest gtest_main)

View File

@@ -32,6 +32,10 @@
 #include <ops/declarable/helpers/col2im.h>
 #include <PointersManager.h>
+#ifdef HAVE_MKLDNN
+#include <ops/declarable/platform/mkldnn/mkldnnUtils.h>
+#endif
 using namespace nd4j;
 using namespace nd4j::graph;

View File

@@ -0,0 +1,70 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <initializer_list>
#ifdef HAVE_MKLDNN
#include <ops/declarable/platform/mkldnn/mkldnnUtils.h>
#endif
class MklDnnTests : public testing::Test {
public:
};
static void printer(std::initializer_list<nd4j::ops::platforms::PlatformHelper*> helpers) {
for (auto v:helpers) {
nd4j_printf("Initialized [%s]\n", v->name().c_str());
}
}
TEST_F(MklDnnTests, helpers_includer) {
// we need this block to make sure all helpers are still available within the binary and not optimized out by the linker
#ifdef HAVE_MKLDNN
nd4j::ops::platforms::PLATFORM_conv2d conv2d;
nd4j::ops::platforms::PLATFORM_conv2d_bp conv2d_bp;
nd4j::ops::platforms::PLATFORM_conv3dnew conv3d;
nd4j::ops::platforms::PLATFORM_conv3dnew_bp conv3d_bp;
nd4j::ops::platforms::PLATFORM_avgpool2d avgpool2d;
nd4j::ops::platforms::PLATFORM_avgpool2d_bp avgpool2d_bp;
nd4j::ops::platforms::PLATFORM_maxpool2d maxpool2d;
nd4j::ops::platforms::PLATFORM_maxpool2d_bp maxpool2d_bp;
nd4j::ops::platforms::PLATFORM_avgpool3dnew avgpool3d;
nd4j::ops::platforms::PLATFORM_avgpool3dnew_bp avgpool3d_bp;
nd4j::ops::platforms::PLATFORM_maxpool3dnew maxpool3d;
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp maxpool3d_bp;
nd4j::ops::platforms::PLATFORM_lrn lrn;
nd4j::ops::platforms::PLATFORM_batchnorm_new batchnorm;
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm});
#endif
}

View File

@@ -21,7 +21,7 @@ endif()
 # OPTIONAL MKL-DNN
 if ("${BUILD_MKLDNN}")
     # Download and unpack mkl-dnn at configure time
-    configure_file(./CMakeLists.txt.in mkldnn-download/CMakeLists.txt)
+    configure_file(../../CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt)
     execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
                     RESULT_VARIABLE result
                     WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download )
@@ -40,7 +40,7 @@ if ("${BUILD_MKLDNN}")
                      EXCLUDE_FROM_ALL)
     set(mkldnn_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src)
     set(HAVE_MKLDNN 1)
-    add_definitions(-DHAVE_MKLDNN=true)
+    add_definitions("-DHAVE_MKLDNN")
     include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_SOURCE_DIR}/external/mklml_lnx_2019.0.3.20190220/include ${mkldnn_SOURCE_DIR})
     set(MKLDNN mkldnn)
 endif()
@@ -131,7 +131,7 @@ else()
 endif()
 if (${F16C})
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c -D__F16C__=true")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c -DSD_F16C=true")
 endif()
 endif()
@@ -177,6 +177,7 @@ if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}
     message(FATAL_ERROR "You need at least GCC 4.9")
 endif()
+message("Looking for OpenMP")
 find_package(OpenMP)
 if (OPENMP_FOUND)
     set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
@@ -185,10 +186,11 @@ else()
     message("OPENMP NOT FOUND")
 endif()
-if ("${OPENBLAS}" OR CMAKE_BUILD_TYPE STREQUAL "Release")
-    find_package(BLAS)
+if ("${OPENBLAS}" OR CMAKE_BUILD_TYPE STREQUAL "Release" OR "${BUILD_MKLDNN}")
+    message("Looking for BLAS")
+    find_package(BLAS REQUIRED)
     if (BLAS_FOUND)
-        message("Found external BLAS implementation...")
+        message("Found external BLAS library: ${BLAS_LIBRARIES}")
         add_definitions(-D__EXTERNAL_BLAS__=true)
     endif()
 endif()
@@ -201,13 +203,18 @@ file(GLOB_RECURSE ARRAY_SOURCES false ../../include/array/*.cpp ../../include/ar
 file(GLOB_RECURSE MEMORY_SOURCES false ../../include/memory/*.cpp ../../include/memory/*.h)
 file(GLOB_RECURSE GRAPH_SOURCES false ../../include/graph/*.cpp ../../include/graph/*.h)
 file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../../include/ops/declarable/generic/*.cpp)
-file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../../include/ops/declarable/helpers/cpu/*.cpp ../../include/ops/declarable/helpers/impl/*.cpp)
+file(GLOB_RECURSE CUSTOMOPS_GENERIC_SOURCES false ../../include/ops/declarable/helpers/cpu/*.cpp ../../include/ops/declarable/helpers/impl/*.cpp)
 file(GLOB_RECURSE OPS_SOURCES false ../../include/ops/impl/*.cpp ../../include/ops/declarable/impl/*.cpp ../../include/ops/*.h)
 file(GLOB_RECURSE INDEXING_SOURCES false ../../include/indexing/*.cpp ../../include/indexing/*.h)
-file(GLOB_RECURSE HELPERS_SOURCES false ../../include/helpers/*.cpp ../../include/helpers/*.h)
+file(GLOB_RECURSE HELPERS_SOURCES false ../../include/helpers/*.cpp)
 file(GLOB_RECURSE LOOPS_SOURCES false ../../include/loops/*.cpp ../../include/loops/*.h)
-message("CPU BLAS")
+# optionally build mkldnn
+if ("${BUILD_MKLDNN}")
+    file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp)
+endif()
+
+message("CPU backend")
 add_definitions(-D__CPUBLAS__=true)
 if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE))
@@ -216,8 +223,37 @@ endif()
 endif()
+# this function strips the path from a file name, leaving the short name, e.g. file.cpp
+function(SHORTNAME LONG_NAME OUTPUT)
+    SET(_TMP_STR "")
+    string (REGEX REPLACE ".*/" "" _TMP_STR "${LONG_NAME}")
+    set (${OUTPUT} "${_TMP_STR}" PARENT_SCOPE)
+endfunction()
+
+# now we need to join two lists:
+# first we build a truncated list of file names in the platform sources,
+# i.e. the list of priority implementations from platform helpers
+#set(CUSTOMOPS_HELPERS_SOURCES "")
+#set(SHORT_NAMES "")
+#foreach(LONG_NAME ${CUSTOMOPS_PLATFORM_SOURCES})
+#    SHORTNAME("${LONG_NAME}" "SHORT_NAME")
+#    set(CUSTOMOPS_HELPERS_SOURCES ${CUSTOMOPS_HELPERS_SOURCES} ${LONG_NAME})
+#    set(SHORT_NAMES ${SHORT_NAMES} ${SHORT_NAME})
+#endforeach()
+
+# then we filter the generic helpers, to exclude platform implementations
+#foreach(LONG_NAME ${CUSTOMOPS_GENERIC_SOURCES})
+#    SHORTNAME("${LONG_NAME}" "SHORT_NAME")
+#    # and we add this op ONLY if it wasn't announced in platform helpers
+#    string(FIND "${SHORT_NAMES}" "${SHORT_NAME}" "LOC")
+#    if (${LOC} EQUAL -1)
+#        set(CUSTOMOPS_HELPERS_SOURCES ${CUSTOMOPS_HELPERS_SOURCES} ${LONG_NAME})
+#    endif()
+#endforeach()
+
 file(GLOB_RECURSE TEST_SOURCES false ../layers_tests/*.cpp ../layers_tests/*.h)
+# file(GLOB_RECURSE TEST_SOURCES false ../layers_tests/DeclarableOpsTests6.cpp ../layers_tests/*.h)
 # Filter out any source files from */CMakeFiles/* paths. These tend to cause problems such as multiple main definitions.
@@ -234,7 +270,7 @@ add_executable(runtests ${LOOPS_SOURCES} ../../blas/cpu/NativeOps.cpp ../../blas
     ../../blas/cpu/NativeOpExecutioner.cpp ../../blas/cpu/NDArray.cpp ../../blas/cpu/NDArrayFactory.cpp
     ../../include/cnpy/cnpy.cpp ../../include/nd4jmemset.h ../../include/nd4jmalloc.h
     ../../blas/Environment.cpp ../../blas/Environment.h ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
-    ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_HELPERS_SOURCES}
+    ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
     ${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES})
 target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES})

View File

@@ -1153,4 +1153,10 @@ public interface NativeOps {
     String lastErrorMessage();
     boolean isBlasVersionMatches(int major, int minor, int build);
+    int binaryLevel();
+    int optimalLevel();
+    boolean isMinimalRequirementsMet();
+    boolean isOptimalRequirementsMet();
 }

View File

@@ -68,6 +68,7 @@ import org.bytedeco.javacpp.tools.InfoMapper;
 //"op_boilerplate.h",
 "ops/InputType.h",
 "ops/declarable/OpDescriptor.h",
+"ops/declarable/PlatformHelper.h",
 "ops/declarable/BroadcastableOp.h",
 "helpers/OpArgsHolder.h",
 "ops/declarable/DeclarableOp.h",

Some files were not shown because too many files have changed in this diff.