- MKL-DNN version upgrade to 1.1.x (#62)

- MKL-DNN namespace changes to match DNNL rename

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-20 13:23:08 +03:00 committed by GitHub
parent 7898f3c0cc
commit 59e955cedc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 836 additions and 836 deletions

View File

@ -51,7 +51,7 @@ endif()
if(NOT CUDA_BLAS) if(NOT CUDA_BLAS)
# we need this definition to avoid global memory use within mkldnn # we need this definition to avoid global memory use within mkldnn
add_definitions(-DMKLDNN_ENABLE_CONCURRENT_EXEC=true) add_definitions(-DDNNL_ENABLE_CONCURRENT_EXEC=true)
# there's a chance, we have no BLAS provided externally # there's a chance, we have no BLAS provided externally
if ("${OPENBLAS_PATH}" STREQUAL "") if ("${OPENBLAS_PATH}" STREQUAL "")
@ -122,7 +122,7 @@ if(NOT CUDA_BLAS)
if (${HELPERS_mkldnn}) if (${HELPERS_mkldnn})
message("Going to pull & build mkldnn") message("Going to pull & build mkldnn")
set(HAVE_MKLDNN 1) set(HAVE_MKLDNN 1)
set(MKLDNN_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE) set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE)
configure_file(./CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt) configure_file(./CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt)
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
@ -146,7 +146,7 @@ if(NOT CUDA_BLAS)
set(mkldnn_EXT_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src) set(mkldnn_EXT_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src)
set(MKLDNN_PATH "${mkldnn_SOURCE_DIR}") set(MKLDNN_PATH "${mkldnn_SOURCE_DIR}")
include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_EXT_DIR}/include ${mkldnn_SOURCE_DIR}) include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_EXT_DIR}/include ${mkldnn_SOURCE_DIR})
set(MKLDNN mkldnn) set(MKLDNN dnnl)
endif() endif()
endif() endif()

View File

@ -5,11 +5,11 @@ project(mkldnn-download NONE)
include(ExternalProject) include(ExternalProject)
ExternalProject_Add(mkldnn ExternalProject_Add(mkldnn
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
GIT_TAG v1.0.4 GIT_TAG v1.1.1
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src" SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
CMAKE_ARGS -DMKLDNN_USE_MKL=ML -DMKLDNN_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\" CMAKE_ARGS -DDNNL_USE_MKL=ML -DDNNL_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\"
BUILD_COMMAND "" BUILD_COMMAND ""
INSTALL_COMMAND "" INSTALL_COMMAND ""
TEST_COMMAND "" TEST_COMMAND ""

View File

@ -30,14 +30,14 @@ thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
#endif #endif
#ifdef HAVE_MKLDNN #ifdef HAVE_MKLDNN
#include <mkldnn.hpp> #include <dnnl.hpp>
#endif #endif
namespace nd4j { namespace nd4j {
LaunchContext::~LaunchContext() { LaunchContext::~LaunchContext() {
#ifdef HAVE_MKLDNN #ifdef HAVE_MKLDNN
delete reinterpret_cast<mkldnn::engine*>(_engine); delete reinterpret_cast<dnnl::engine*>(_engine);
#endif #endif
} }
@ -50,7 +50,7 @@ namespace nd4j {
_deviceID = 0; _deviceID = 0;
#ifdef HAVE_MKLDNN #ifdef HAVE_MKLDNN
_engine = new mkldnn::engine(mkldnn::engine::kind::cpu, 0); _engine = new dnnl::engine(dnnl::engine::kind::cpu, 0);
#endif #endif
} }

View File

@ -27,7 +27,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -82,11 +82,11 @@ namespace nd4j {
auto poolingMode = PoolingType::AVG_POOL; auto poolingMode = PoolingType::AVG_POOL;
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty); dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty); dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm; dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true, true,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
@ -99,20 +99,20 @@ namespace nd4j {
pool_strides, pool_kernel, pool_padding, pool_padding_r); pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory; auto pool_src_memory = user_src_memory;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
} }
auto pool_dst_memory = user_dst_memory; auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
} }
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}}); {DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
} }

View File

@ -27,7 +27,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -89,11 +89,11 @@ namespace nd4j {
auto poolingMode = PoolingType::AVG_POOL; auto poolingMode = PoolingType::AVG_POOL;
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm; dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true, true,
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
@ -109,20 +109,20 @@ namespace nd4j {
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
pool_kernel, pool_padding, pool_padding_r); pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer()); auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer()); auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory; auto poolB_src_memory = userB_src_memory;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine); poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
} }
auto poolB_dst_memory = userB_dst_memory; auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine); poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
} }
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory}, pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}}); {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
} }

View File

@ -27,7 +27,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -86,11 +86,11 @@ namespace nd4j {
auto poolingMode = PoolingType::AVG_POOL; auto poolingMode = PoolingType::AVG_POOL;
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty); dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty); dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm; dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true, extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
@ -102,21 +102,21 @@ namespace nd4j {
pool_dst_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r); pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory; auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
} }
auto pool_dst_memory = user_dst_memory; auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
} }
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}}); {DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
} }

View File

@ -26,7 +26,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -92,11 +92,11 @@ namespace nd4j {
auto poolingMode = PoolingType::AVG_POOL; auto poolingMode = PoolingType::AVG_POOL;
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm; dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true, extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
@ -111,24 +111,24 @@ namespace nd4j {
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r); pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
pool_kernel, pool_padding, pool_padding_r); pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer()); auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer()); auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory; auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine); poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
} }
auto poolB_dst_memory = userB_dst_memory; auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine); poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
} }
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory}, pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}}); {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
} }

View File

@ -39,7 +39,7 @@ namespace platforms {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) { static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) {
// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any) // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
// also it gives wrong results for formats nhwc and ndhwc // also it gives wrong results for formats nhwc and ndhwc
// x -> 2D:nc, 4D:nchw, 5D:ncdhw // x -> 2D:nc, 4D:nchw, 5D:ncdhw
@ -53,35 +53,35 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// input type // input type
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; dnnl::memory::data_type type = dnnl::memory::data_type::f32;
// indicate whether gamma or/and beta are given // indicate whether gamma or/and beta are given
auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
if (weights != nullptr) if (weights != nullptr)
flags |= mkldnn::normalization_flags::use_scale_shift; flags |= dnnl::normalization_flags::use_scale_shift;
mkldnn::memory::dims dims; dnnl::memory::dims dims;
mkldnn::memory::format_tag format; dnnl::memory::format_tag format;
if(xRank == 2) { if(xRank == 2) {
dims = {x->sizeAt(0), x->sizeAt(1)}; dims = {x->sizeAt(0), x->sizeAt(1)};
format = mkldnn::memory::format_tag::nc; format = dnnl::memory::format_tag::nc;
} }
else if(xRank == 4) { else if(xRank == 4) {
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
format = mkldnn::memory::format_tag::nchw; format = dnnl::memory::format_tag::nchw;
} }
else { // xRank = 5 else { // xRank = 5
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
format = mkldnn::memory::format_tag::ncdhw; format = dnnl::memory::format_tag::ncdhw;
} }
// memory descriptors for arrays // memory descriptors for arrays
// x // x
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
if(xRank > 2) { if(xRank > 2) {
@ -92,9 +92,9 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
// z, output // z, output
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, format);
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
z_user_md.data.format_kind = mkldnn_blocked; // overrides format z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0]; z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0];
z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1]; z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1];
if(xRank > 2) { if(xRank > 2) {
@ -106,53 +106,53 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
// batchnorm forward description // batchnorm forward description
mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// arguments (memory buffers) necessary for calculations // arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args; std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
// provide memory and check whether reorder is required // provide memory and check whether reorder is required
// x // x
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc(); const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder) if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem; args[DNNL_ARG_SRC] = x_mkl_mem;
// z // z
auto z_user_mem = mkldnn::memory(z_user_md, engine, z->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? mkldnn::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
if (zReorder) if (zReorder)
mkldnn::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem); dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
args[MKLDNN_ARG_DST] = z_mkl_mem; args[DNNL_ARG_DST] = z_mkl_mem;
// mean // mean
auto mean_mkl_mem = mkldnn::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer()); auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer());
args[MKLDNN_ARG_MEAN] = mean_mkl_mem; args[DNNL_ARG_MEAN] = mean_mkl_mem;
// variance // variance
auto var_mkl_mem = mkldnn::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer()); auto var_mkl_mem = dnnl::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer());
args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; args[DNNL_ARG_VARIANCE] = var_mkl_mem;
// gamma and beta (and their gradients) if they are present // gamma and beta (and their gradients) if they are present
if(weights != nullptr) { if(weights != nullptr) {
auto w_mkl_mem = mkldnn::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer()); auto w_mkl_mem = dnnl::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer());
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
} }
// run calculations // run calculations
mkldnn::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);
// reorder outputs if necessary // reorder outputs if necessary
if (zReorder) if (zReorder)
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait(); stream.wait();
@ -164,7 +164,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights, static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
const float epsilon, NDArray* dLdI, NDArray* dLdW) { const float epsilon, NDArray* dLdI, NDArray* dLdW) {
// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any) // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
// also it gives wrong results for formats nhwc and ndhwc // also it gives wrong results for formats nhwc and ndhwc
// x -> 2D:nc, 4D:nchw, 5D:ncdhw // x -> 2D:nc, 4D:nchw, 5D:ncdhw
@ -180,35 +180,35 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// input type // input type
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; dnnl::memory::data_type type = dnnl::memory::data_type::f32;
// indicate whether gamma or/and beta are given // indicate whether gamma or/and beta are given
auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
if (weights != nullptr) if (weights != nullptr)
flags |= mkldnn::normalization_flags::use_scale_shift; flags |= dnnl::normalization_flags::use_scale_shift;
mkldnn::memory::dims dims; dnnl::memory::dims dims;
mkldnn::memory::format_tag format; dnnl::memory::format_tag format;
if(xRank == 2) { if(xRank == 2) {
dims = {x->sizeAt(0), x->sizeAt(1)}; dims = {x->sizeAt(0), x->sizeAt(1)};
format = mkldnn::memory::format_tag::nc; format = dnnl::memory::format_tag::nc;
} }
else if(xRank == 4) { else if(xRank == 4) {
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
format = mkldnn::memory::format_tag::nchw; format = dnnl::memory::format_tag::nchw;
} }
else { // xRank = 5 else { // xRank = 5
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
format = mkldnn::memory::format_tag::ncdhw; format = dnnl::memory::format_tag::ncdhw;
} }
// memory descriptors for arrays // memory descriptors for arrays
// x // x
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
if(xRank > 2) { if(xRank > 2) {
@ -219,9 +219,9 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
// dLdO // dLdO
mkldnn::memory::desc dLdO_mkl_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, format);
mkldnn::memory::desc dLdO_user_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
dLdO_user_md.data.format_kind = mkldnn_blocked; // overrides format dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format
dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0]; dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0];
dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1]; dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1];
if(xRank > 2) { if(xRank > 2) {
@ -232,9 +232,9 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4]; dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4];
// dLdI // dLdI
mkldnn::memory::desc dLdI_mkl_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, format);
mkldnn::memory::desc dLdI_user_md = mkldnn::memory::desc(dims, type, format); dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
dLdI_user_md.data.format_kind = mkldnn_blocked; // overrides format dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format
dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0]; dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0];
dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1]; dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1];
if(xRank > 2) { if(xRank > 2) {
@ -245,66 +245,66 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4]; dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4];
// batchnorm forward description // batchnorm forward description
mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// batchnorm backprop description // batchnorm backprop description
mkldnn::batch_normalization_backward::desc op_bp_desc(mkldnn::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); dnnl::batch_normalization_backward::desc op_bp_desc(dnnl::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags);
mkldnn::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); dnnl::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations // arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args; std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
// provide memory and check whether reorder is required // provide memory and check whether reorder is required
// x // x
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc(); const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem; auto x_mkl_mem = xReorder ? dnnl::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder) if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem; args[DNNL_ARG_SRC] = x_mkl_mem;
// dLdO // dLdO
auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer()); auto dLdO_user_mem = dnnl::memory(dLdO_user_md, engine, dLdO->getBuffer());
const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc(); const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc();
auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem; auto dLdO_mkl_mem = dLdOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
if (dLdOReorder) if (dLdOReorder)
mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem); dnnl::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem; args[DNNL_ARG_DIFF_DST] = dLdO_mkl_mem;
// mean // mean
auto mean_mkl_mem = mkldnn::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
args[MKLDNN_ARG_MEAN] = mean_mkl_mem; args[DNNL_ARG_MEAN] = mean_mkl_mem;
// variance // variance
auto var_mkl_mem = mkldnn::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer()); auto var_mkl_mem = dnnl::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer());
args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; args[DNNL_ARG_VARIANCE] = var_mkl_mem;
// dLdI // dLdI
auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer()); auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->getBuffer());
const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc(); const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc();
auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem; auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem; args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem;
// gamma and beta (and their gradients) if they are present // gamma and beta (and their gradients) if they are present
if(weights != nullptr) { if(weights != nullptr) {
auto w_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer()); auto w_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer());
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
auto dLdW_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer()); auto dLdW_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer());
args[MKLDNN_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; args[DNNL_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem;
} }
// run calculations // run calculations
mkldnn::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);
// reorder outputs if necessary // reorder outputs if necessary
if (dLdIReorder) if (dLdIReorder)
mkldnn::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem); dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
stream.wait(); stream.wait();
@ -532,37 +532,37 @@ PLATFORM_CHECK(batchnorm) {
// weights({1, 2, 0, 0}).assign(0.0f); // weights({1, 2, 0, 0}).assign(0.0f);
// mkldnn_memory_desc_t empty; // mkldnn_memory_desc_t empty;
// mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty); // dnnl::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
// auto flag = mkldnn::normalization_flags::use_global_stats; // auto flag = dnnl::normalization_flags::use_global_stats;
// if (applyScale || applyOffset) // if (applyScale || applyOffset)
// flag |= mkldnn::normalization_flags::use_scale_shift; // flag |= dnnl::normalization_flags::use_scale_shift;
// mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output, // mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
// &batchnorm_src_md, nullptr, &batchnorm_dst_md, // &batchnorm_src_md, nullptr, &batchnorm_dst_md,
// &user_src_md, nullptr, &user_dst_md, axes[0]); // &user_src_md, nullptr, &user_dst_md, axes[0]);
// auto batchnorm_desc = mkldnn::batch_normalization_forward::desc(mkldnn::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag); // auto batchnorm_desc = dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag);
// auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// mkldnn::stream stream(engine); // dnnl::stream stream(engine);
// auto batchnorm_prim_desc = mkldnn::batch_normalization_forward::primitive_desc(batchnorm_desc, engine); // auto batchnorm_prim_desc = dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
// auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); // auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
// auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); // auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
// auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine, // auto batchnorm_mean_memory = dnnl::memory(batchnorm_prim_desc.mean_desc(), engine,
// mean->buffer()); // mean->buffer());
// auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine, // auto batchnorm_variance_memory = dnnl::memory(batchnorm_prim_desc.variance_desc(), engine,
// variance->buffer()); // variance->buffer());
// auto batchnorm_src_memory = user_src_memory; // auto batchnorm_src_memory = user_src_memory;
// mkldnn::memory m(batchnorm_src_md, engine); // dnnl::memory m(batchnorm_src_md, engine);
// if (m.get_desc() != user_src_memory.get_desc()) { // if (m.get_desc() != user_src_memory.get_desc()) {
// batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine); // batchnorm_src_memory = dnnl::memory(batchnorm_src_md, engine);
// mkldnn::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, // dnnl::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
// batchnorm_src_memory); // batchnorm_src_memory);
// } // }
// auto batchnorm_dst_memory = user_dst_memory; // auto batchnorm_dst_memory = user_dst_memory;
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
// batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine); // batchnorm_dst_memory = dnnl::memory(batchnorm_prim_desc.dst_desc(), engine);
// } // }
// if (applyScale || applyOffset) { // if (applyScale || applyOffset) {
// if (gamma != nullptr) { // if (gamma != nullptr) {
@ -572,22 +572,22 @@ PLATFORM_CHECK(batchnorm) {
// weights({1, 2, 0, 0}).assign(beta); // weights({1, 2, 0, 0}).assign(beta);
// } // }
// auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); // auto batchnorm_weights_memory = dnnl::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream, // dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, // {{MKLDNN_ARG_SRC, batchnorm_src_memory},
// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, // {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, // {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
// {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory}, // {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory},
// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); // {MKLDNN_ARG_DST, batchnorm_dst_memory}});
// } else { // } else {
// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream, // dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, // {{MKLDNN_ARG_SRC, batchnorm_src_memory},
// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, // {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, // {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); // {MKLDNN_ARG_DST, batchnorm_dst_memory}});
// } // }
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
// mkldnn::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, // dnnl::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
// user_dst_memory); // user_dst_memory);
// } // }
// stream.wait(); // stream.wait();

View File

@ -27,7 +27,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -47,12 +47,12 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
if(isSameMode) // SAME if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
empty); empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
empty); empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
bias, output, bias, output,
@ -74,38 +74,38 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r); conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer()); auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = mkldnn::memory(user_weights_md, engine, auto user_weights_memory = dnnl::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer()); const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto conv_src_memory = user_src_memory; auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine); conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
} }
auto conv_weights_memory = user_weights_memory; auto conv_weights_memory = user_weights_memory;
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine); conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
conv_weights_memory); conv_weights_memory);
} }
auto conv_dst_memory = user_dst_memory; auto conv_dst_memory = user_dst_memory;
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine); conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine);
} }
if (bias != nullptr) { if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine,
const_cast<NDArray *>(bias)->buffer()); const_cast<NDArray *>(bias)->buffer());
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory}, convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_BIAS, conv_bias_memory}, {DNNL_ARG_BIAS, conv_bias_memory},
{MKLDNN_ARG_DST, conv_dst_memory}}); {DNNL_ARG_DST, conv_dst_memory}});
} else { } else {
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory}, convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_DST, conv_dst_memory}}); {DNNL_ARG_DST, conv_dst_memory}});
} }
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
@ -198,12 +198,12 @@ PLATFORM_IMPL(conv2d_bp) {
if (isSameMode) // SAME if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
gradB, gradO, gradB, gradO,
@ -235,47 +235,47 @@ PLATFORM_IMPL(conv2d_bp) {
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc); conv_prim_desc);
auto userW_src_memory = mkldnn::memory(user_src_md, engine, auto userW_src_memory = dnnl::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer()); const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer()); auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine, auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer()); const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory; auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine); convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory); convW_src_memory);
} }
auto convW_weights_memory = userW_weights_memory; auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine); convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
} }
auto convW_dst_memory = userW_dst_memory; auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine); convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory); convW_dst_memory);
} }
if (gradB != nullptr) { if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine, auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer()); gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream, convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory}, {{DNNL_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}, {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}}); {DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
} else { } else {
convolution_backward_weights(convW_prim_desc).execute(stream, convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory}, {{DNNL_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}}); {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
} }
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
@ -293,38 +293,38 @@ PLATFORM_IMPL(conv2d_bp) {
conv_padding, conv_padding_r); conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc); conv_prim_desc);
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer()); auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine, auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer()); const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine, auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer()); const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory; auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine); convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
} }
auto convI_weights_memory = userI_weights_memory; auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine); convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory); convI_weights_memory);
} }
auto convI_dst_memory = userI_dst_memory; auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine); convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory); convI_dst_memory);
} }
convolution_backward_data(convI_prim_desc).execute(stream, convolution_backward_data(convI_prim_desc).execute(stream,
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory}, {{DNNL_ARG_DIFF_DST, convI_dst_memory},
{MKLDNN_ARG_WEIGHTS, convI_weights_memory}, {DNNL_ARG_WEIGHTS, convI_weights_memory},
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}}); {DNNL_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,

View File

@ -27,7 +27,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -83,12 +83,12 @@ PLATFORM_IMPL(conv3dnew) {
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
empty); empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
empty); empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
isNCDHW, isNCDHW,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
@ -110,37 +110,37 @@ PLATFORM_IMPL(conv3dnew) {
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r); conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer()); auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = mkldnn::memory(user_weights_md, engine, auto user_weights_memory = dnnl::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer()); const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto conv_src_memory = user_src_memory; auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine); conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
} }
auto conv_weights_memory = user_weights_memory; auto conv_weights_memory = user_weights_memory;
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine); conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
conv_weights_memory); conv_weights_memory);
} }
auto conv_dst_memory = user_dst_memory; auto conv_dst_memory = user_dst_memory;
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine); conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine);
} }
if (bias != nullptr) { if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, bias->buffer()); auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory}, convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_BIAS, conv_bias_memory}, {DNNL_ARG_BIAS, conv_bias_memory},
{MKLDNN_ARG_DST, conv_dst_memory}}); {DNNL_ARG_DST, conv_dst_memory}});
} else { } else {
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory}, convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_DST, conv_dst_memory}}); {DNNL_ARG_DST, conv_dst_memory}});
} }
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
@ -235,12 +235,12 @@ PLATFORM_IMPL(conv3dnew_bp) {
oC, bias->rankOf(), bias->lengthOf()); oC, bias->rankOf(), bias->lengthOf());
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
isNDHWC, isNDHWC,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
@ -273,47 +273,47 @@ PLATFORM_IMPL(conv3dnew_bp) {
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc); conv_prim_desc);
auto userW_src_memory = mkldnn::memory(user_src_md, engine, auto userW_src_memory = dnnl::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer()); const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer()); auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine, auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer()); const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory; auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine); convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory); convW_src_memory);
} }
auto convW_weights_memory = userW_weights_memory; auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine); convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
} }
auto convW_dst_memory = userW_dst_memory; auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine); convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory); convW_dst_memory);
} }
if (gradB != nullptr) { if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine, auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer()); gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream, convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory}, {{DNNL_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}, {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}}); {DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
} else { } else {
convolution_backward_weights(convW_prim_desc).execute(stream, convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory}, {{DNNL_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}}); {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
} }
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
@ -330,38 +330,38 @@ PLATFORM_IMPL(conv3dnew_bp) {
conv_padding_r); conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc); conv_prim_desc);
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer()); auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine, auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer()); const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine, auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer()); const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory; auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine); convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
} }
auto convI_weights_memory = userI_weights_memory; auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine); convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory); convI_weights_memory);
} }
auto convI_dst_memory = userI_dst_memory; auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine); convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory); convI_dst_memory);
} }
convolution_backward_data(convI_prim_desc).execute(stream, convolution_backward_data(convI_prim_desc).execute(stream,
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory}, {{DNNL_ARG_DIFF_DST, convI_dst_memory},
{MKLDNN_ARG_WEIGHTS, convI_weights_memory}, {DNNL_ARG_WEIGHTS, convI_weights_memory},
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}}); {DNNL_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,

View File

@ -49,77 +49,77 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW); int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl); ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sH, sW }; dnnl::memory::dims strides = { sH, sW };
mkldnn::memory::dims padding = { pH, pW }; dnnl::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { pHmkl, pWmkl }; dnnl::memory::dims padding_r = { pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dHmkl, dWmkl }; dnnl::memory::dims dilation = { dHmkl, dWmkl };
// input type // input type
mkldnn::memory::data_type xType; dnnl::memory::data_type xType;
if(input->dataType() == DataType::FLOAT32) if(input->dataType() == DataType::FLOAT32)
xType = mkldnn::memory::data_type::f32; xType = dnnl::memory::data_type::f32;
else if(input->dataType() == DataType::HALF) else if(input->dataType() == DataType::HALF)
xType = mkldnn::memory::data_type::f16; xType = dnnl::memory::data_type::f16;
else if(input->dataType() == DataType::UINT8) else if(input->dataType() == DataType::UINT8)
xType = mkldnn::memory::data_type::u8; xType = dnnl::memory::data_type::u8;
else else
xType = mkldnn::memory::data_type::s8; xType = dnnl::memory::data_type::s8;
// weights type // weights type
mkldnn::memory::data_type wType = xType; dnnl::memory::data_type wType = xType;
if(xType == mkldnn::memory::data_type::u8) if(xType == dnnl::memory::data_type::u8)
wType = mkldnn::memory::data_type::s8; wType = dnnl::memory::data_type::s8;
// output and bias type (have the same types) // output and bias type (have the same types)
mkldnn::memory::data_type zType; dnnl::memory::data_type zType;
if(output->dataType() == DataType::FLOAT32) if(output->dataType() == DataType::FLOAT32)
zType = mkldnn::memory::data_type::f32; zType = dnnl::memory::data_type::f32;
else if(output->dataType() == DataType::HALF) else if(output->dataType() == DataType::HALF)
zType = mkldnn::memory::data_type::f16; zType = dnnl::memory::data_type::f16;
else if(output->dataType() == DataType::UINT8) else if(output->dataType() == DataType::UINT8)
zType = mkldnn::memory::data_type::u8; zType = dnnl::memory::data_type::u8;
else if(output->dataType() == DataType::INT8) else if(output->dataType() == DataType::INT8)
zType = mkldnn::memory::data_type::s8; zType = dnnl::memory::data_type::s8;
else else
zType = mkldnn::memory::data_type::s32; zType = dnnl::memory::data_type::s32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
mkldnn::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oH, oW}; dnnl::memory::dims zDims = {bS, oC, oH, oW};
// memory descriptors for arrays // memory descriptors for arrays
// input // input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3]; x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
// weights // weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
// bias // bias
mkldnn::memory::desc b_mkl_md; dnnl::memory::desc b_mkl_md;
if(bias != nullptr) if(bias != nullptr)
b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x); b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x);
// output // output
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
z_user_md.data.format_kind = mkldnn_blocked; // overrides format z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0]; z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1]; z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2]; z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
@ -128,51 +128,51 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// operation primitive description // operation primitive description
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct,
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
// arguments (memory buffers) necessary for calculations // arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args; std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer()); auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem; auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder) if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem; args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder) if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer()); auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
args[MKLDNN_ARG_BIAS] = b_mkl_mem; args[DNNL_ARG_BIAS] = b_mkl_mem;
} }
// output // output
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[MKLDNN_ARG_DST] = z_mkl_mem; args[DNNL_ARG_DST] = z_mkl_mem;
// run calculations // run calculations
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args); dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary // reorder outputs if necessary
if (zReorder) if (zReorder)
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait(); stream.wait();
@ -196,157 +196,157 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW); int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl); ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sH, sW }; dnnl::memory::dims strides = { sH, sW };
mkldnn::memory::dims padding = { pH, pW }; dnnl::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { pHmkl, pWmkl }; dnnl::memory::dims padding_r = { pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dHmkl, dWmkl }; dnnl::memory::dims dilation = { dHmkl, dWmkl };
// input type // input type
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// weights type // weights type
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradO type // gradO type
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradI type // gradI type
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradW type // gradW type
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradB type // gradB type
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32; dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
mkldnn::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oH, oW}; dnnl::memory::dims zDims = {bS, oC, oH, oW};
// memory descriptors for arrays // memory descriptors for arrays
// input // input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3]; x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
// weights // weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
// gradO // gradO
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0]; gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1]; gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2]; gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3]; gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
// gradI // gradI
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0]; gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1]; gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2]; gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3]; gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
// gradW // gradW
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, mkldnn::memory::format_tag::any); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0]; gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1]; gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2]; gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3]; gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
// gradB // gradB
mkldnn::memory::desc gradB_mkl_md; dnnl::memory::desc gradB_mkl_md;
if(gradB != nullptr) if(gradB != nullptr)
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x); gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description // forward primitive description
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward data primitive description // backward data primitive description
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
// backward weights primitive description // backward weights primitive description
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations // arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args; std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer()); auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder) if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem; args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder) if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO // gradO
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder) if (gradOReorder)
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem; args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI // gradI
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer()); auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem; args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
// gradW // gradW
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer()); auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
// gradB // gradB
if(gradB != nullptr) { if(gradB != nullptr) {
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer()); auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem; args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
} }
// run backward data calculations // run backward data calculations
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// run backward weights calculations // run backward weights calculations
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary // reorder gradI if necessary
if (gradIReorder) if (gradIReorder)
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder) if (gradWReorder)
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
stream.wait(); stream.wait();

View File

@ -39,52 +39,52 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC] // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
// gradO [bS, oH, oW, oC] // gradO [bS, oH, oW, oC]
mkldnn::memory::dims strides = { sH, sW }; dnnl::memory::dims strides = { sH, sW };
mkldnn::memory::dims dilation = { dH - 1, dW - 1 }; dnnl::memory::dims dilation = { dH - 1, dW - 1 };
mkldnn::memory::dims padding = { pH, pW }; dnnl::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
// weights type // weights type
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradO type // gradO type
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradI type // gradI type
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
mkldnn::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oH, oW}; dnnl::memory::dims zDims = {bS, oC, oH, oW};
// memory descriptors for arrays // memory descriptors for arrays
// input // input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, gradOType, mkldnn::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, gradOType, dnnl::memory::format_tag::any);
// weights // weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
// gradO // gradO
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0]; gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1]; gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2]; gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3]; gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
// gradI // gradI
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0]; gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1]; gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2]; gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
@ -94,48 +94,48 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description // forward primitive description
mkldnn::convolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward data primitive description // backward data primitive description
mkldnn::convolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations // arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args; std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// weights // weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder) if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO // gradO
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder) if (gradOReorder)
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem; args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI // gradI
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer()); auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem; args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
// run backward data calculations // run backward data calculations
mkldnn::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary // reorder gradI if necessary
if (gradIReorder) if (gradIReorder)
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
stream.wait(); stream.wait();

View File

@ -50,54 +50,54 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW); int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl); ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sD, sH, sW }; dnnl::memory::dims strides = { sD, sH, sW };
mkldnn::memory::dims padding = { pD, pH, pW }; dnnl::memory::dims padding = { pD, pH, pW };
mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl }; dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl }; dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
// input type // input type
mkldnn::memory::data_type xType; dnnl::memory::data_type xType;
if(input->dataType() == DataType::FLOAT32) if(input->dataType() == DataType::FLOAT32)
xType = mkldnn::memory::data_type::f32; xType = dnnl::memory::data_type::f32;
else if(input->dataType() == DataType::HALF) else if(input->dataType() == DataType::HALF)
xType = mkldnn::memory::data_type::f16; xType = dnnl::memory::data_type::f16;
else if(input->dataType() == DataType::UINT8) else if(input->dataType() == DataType::UINT8)
xType = mkldnn::memory::data_type::u8; xType = dnnl::memory::data_type::u8;
else else
xType = mkldnn::memory::data_type::s8; xType = dnnl::memory::data_type::s8;
// weights type // weights type
mkldnn::memory::data_type wType = xType; dnnl::memory::data_type wType = xType;
if(xType == mkldnn::memory::data_type::u8) if(xType == dnnl::memory::data_type::u8)
wType = mkldnn::memory::data_type::s8; wType = dnnl::memory::data_type::s8;
// output and bias type (have the same types) // output and bias type (have the same types)
mkldnn::memory::data_type zType; dnnl::memory::data_type zType;
if(output->dataType() == DataType::FLOAT32) if(output->dataType() == DataType::FLOAT32)
zType = mkldnn::memory::data_type::f32; zType = dnnl::memory::data_type::f32;
else if(output->dataType() == DataType::HALF) else if(output->dataType() == DataType::HALF)
zType = mkldnn::memory::data_type::f16; zType = dnnl::memory::data_type::f16;
else if(output->dataType() == DataType::UINT8) else if(output->dataType() == DataType::UINT8)
zType = mkldnn::memory::data_type::u8; zType = dnnl::memory::data_type::u8;
else if(output->dataType() == DataType::INT8) else if(output->dataType() == DataType::INT8)
zType = mkldnn::memory::data_type::s8; zType = dnnl::memory::data_type::s8;
else else
zType = mkldnn::memory::data_type::s32; zType = dnnl::memory::data_type::s32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw; dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW}; dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
// memory descriptors for arrays // memory descriptors for arrays
// input // input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
@ -105,9 +105,9 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4]; x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
// weights // weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
@ -115,14 +115,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4]; w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
// bias // bias
mkldnn::memory::desc b_mkl_md; dnnl::memory::desc b_mkl_md;
if(bias != nullptr) if(bias != nullptr)
b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x); b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x);
// output // output
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
z_user_md.data.format_kind = mkldnn_blocked; // overrides format z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0]; z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1]; z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2]; z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
@ -132,51 +132,51 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// operation primitive description // operation primitive description
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct,
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
// arguments (memory buffers) necessary for calculations // arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args; std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer()); auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem; auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder) if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem; args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder) if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer()); auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
args[MKLDNN_ARG_BIAS] = b_mkl_mem; args[DNNL_ARG_BIAS] = b_mkl_mem;
} }
// output // output
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[MKLDNN_ARG_DST] = z_mkl_mem; args[DNNL_ARG_DST] = z_mkl_mem;
// run calculations // run calculations
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args); dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary // reorder outputs if necessary
if (zReorder) if (zReorder)
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait(); stream.wait();
@ -200,37 +200,37 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW); int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl); ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sD, sH, sW }; dnnl::memory::dims strides = { sD, sH, sW };
mkldnn::memory::dims padding = { pD, pH, pW }; dnnl::memory::dims padding = { pD, pH, pW };
mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl }; dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl }; dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
// input type // input type
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// weights type // weights type
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradO type // gradO type
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradI type // gradI type
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradW type // gradW type
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradB type // gradB type
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32; dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw; // isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc; dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw; // isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW}; dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
// memory descriptors for arrays // memory descriptors for arrays
// input // input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
@ -238,9 +238,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4]; x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
// weights // weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
@ -248,9 +248,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4]; w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
// gradO // gradO
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0]; gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1]; gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2]; gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
@ -258,9 +258,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4]; gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4];
// gradI // gradI
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0]; gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1]; gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2]; gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
@ -268,9 +268,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4]; gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4];
// gradW // gradW
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, wFormat); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat);
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0]; gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1]; gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2]; gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
@ -278,85 +278,85 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4]; gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4];
// gradB // gradB
mkldnn::memory::desc gradB_mkl_md; dnnl::memory::desc gradB_mkl_md;
if(gradB != nullptr) if(gradB != nullptr)
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x); gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description // forward primitive description
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward data primitive description // backward data primitive description
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
// backward weights primitive description // backward weights primitive description
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations // arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args; std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer()); auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder) if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem; args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder) if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO // gradO
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder) if (gradOReorder)
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem; args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI // gradI
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer()); auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem; args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
// gradW // gradW
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer()); auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
// gradB // gradB
if(gradB != nullptr) { if(gradB != nullptr) {
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer()); auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem; args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
} }
// run backward data calculations // run backward data calculations
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// run backward weights calculations // run backward weights calculations
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary // reorder gradI if necessary
if (gradIReorder) if (gradIReorder)
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder) if (gradWReorder)
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
stream.wait(); stream.wait();

View File

@ -27,7 +27,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -44,8 +44,8 @@ namespace nd4j {
double bias = T_ARG(0); double bias = T_ARG(0);
int depth = INT_ARG(0); int depth = INT_ARG(0);
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty); dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);
mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md, mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md,
&user_src_md, nullptr, &user_dst_md, input->rankOf() - 1); &user_src_md, nullptr, &user_dst_md, input->rankOf() - 1);
@ -54,24 +54,24 @@ namespace nd4j {
lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias); lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine); auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto lrn_src_memory = user_src_memory; auto lrn_src_memory = user_src_memory;
if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) { if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) {
lrn_src_memory = mkldnn::memory(lrn_prim_desc.src_desc(), engine); lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine);
reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory); reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory);
} }
auto lrn_dst_memory = user_dst_memory; auto lrn_dst_memory = user_dst_memory;
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
lrn_dst_memory = mkldnn::memory(lrn_prim_desc.dst_desc(), engine); lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine);
} }
lrn_forward(lrn_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, lrn_src_memory}, lrn_forward(lrn_prim_desc).execute(stream, {{DNNL_ARG_SRC, lrn_src_memory},
{MKLDNN_ARG_DST, lrn_dst_memory}}); {DNNL_ARG_DST, lrn_dst_memory}});
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory); reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory);

View File

@ -21,7 +21,7 @@
#include <ops/declarable/OpRegistrator.h> #include <ops/declarable/OpRegistrator.h>
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -132,52 +132,52 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md, dnnl::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md; x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md;
// input type // input type
mkldnn::memory::data_type xType; dnnl::memory::data_type xType;
if(x->dataType() == DataType::FLOAT32) if(x->dataType() == DataType::FLOAT32)
xType = mkldnn::memory::data_type::f32; xType = dnnl::memory::data_type::f32;
else if(x->dataType() == DataType::HALF) else if(x->dataType() == DataType::HALF)
xType = mkldnn::memory::data_type::f16; xType = dnnl::memory::data_type::f16;
else else
xType = mkldnn::memory::data_type::u8; xType = dnnl::memory::data_type::u8;
// weights type // weights type
mkldnn::memory::data_type wType = xType; dnnl::memory::data_type wType = xType;
if(xType == mkldnn::memory::data_type::u8) if(xType == dnnl::memory::data_type::u8)
wType = mkldnn::memory::data_type::s8; wType = dnnl::memory::data_type::s8;
// bias type // bias type
mkldnn::memory::data_type bType = xType; dnnl::memory::data_type bType = xType;
if(xType == mkldnn::memory::data_type::u8) if(xType == dnnl::memory::data_type::u8)
bType = mkldnn::memory::data_type::f32; bType = dnnl::memory::data_type::f32;
// output type // output type
mkldnn::memory::data_type hType; dnnl::memory::data_type hType;
if(h->dataType() == DataType::FLOAT32) if(h->dataType() == DataType::FLOAT32)
hType = mkldnn::memory::data_type::f32; hType = dnnl::memory::data_type::f32;
else if(h->dataType() == DataType::HALF) else if(h->dataType() == DataType::HALF)
hType = mkldnn::memory::data_type::f16; hType = dnnl::memory::data_type::f16;
else else
hType = mkldnn::memory::data_type::u8; hType = dnnl::memory::data_type::u8;
// memory descriptors for arrays // memory descriptors for arrays
// x // x
x_lstm_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::any); x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any);
// x_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, nIn}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, nIn}, type, mkldnn::memory::format_tag::ntc); // x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc);
x_user_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::tnc); x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
// wx // wx
wx_lstm_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::any); wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any);
wx_user_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::ldigo); wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
wx_user_md.data.format_kind = mkldnn_blocked; // overrides format wx_user_md.data.format_kind = dnnl_blocked; // overrides format
wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0]; wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0];
wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1]; wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1];
wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2]; wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2];
@ -185,9 +185,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4]; wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4];
// wr // wr
wr_lstm_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::any); wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any);
wr_user_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::ldigo); wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
wr_user_md.data.format_kind = mkldnn_blocked; // overrides format wr_user_md.data.format_kind = dnnl_blocked; // overrides format
wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0]; wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0];
wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1]; wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1];
wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2]; wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2];
@ -195,19 +195,19 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4]; wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4];
// h // h
h_lstm_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::any); h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any);
// h_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, hDirDim*nOut}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, hDirDim*nOut}, type, mkldnn::memory::format_tag::ntc); // h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc);
h_user_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::tnc); h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc);
h_user_md.data.format_kind = mkldnn_blocked; // overrides format h_user_md.data.format_kind = dnnl_blocked; // overrides format
h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0]; h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0];
h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1]; h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1];
h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2]; h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2];
// b // b
if(b) { if(b) {
b_lstm_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::any); b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any);
b_user_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::ldgo); b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo);
b_user_md.data.format_kind = mkldnn_blocked; // overrides format b_user_md.data.format_kind = dnnl_blocked; // overrides format
b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0]; b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0];
b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1]; b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1];
b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2]; b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2];
@ -216,9 +216,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// hI // hI
if(hI) { if(hI) {
hI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any); hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
hI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc); hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
hI_user_md.data.format_kind = mkldnn_blocked; // overrides format hI_user_md.data.format_kind = dnnl_blocked; // overrides format
hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0]; hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0];
hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1]; hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1];
hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2]; hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2];
@ -227,9 +227,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// cI // cI
if(cI) { if(cI) {
cI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any); cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
cI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc); cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
cI_user_md.data.format_kind = mkldnn_blocked; // overrides format cI_user_md.data.format_kind = dnnl_blocked; // overrides format
cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0]; cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0];
cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1]; cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1];
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2]; cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2];
@ -238,9 +238,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// hL // hL
if(hL) { if(hL) {
hL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::any); hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any);
hL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc); hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
hL_user_md.data.format_kind = mkldnn_blocked; // overrides format hL_user_md.data.format_kind = dnnl_blocked; // overrides format
hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0]; hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0];
hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1]; hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1];
hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2]; hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2];
@ -248,9 +248,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
} }
if(cL) { if(cL) {
cL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc); cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
cL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc); cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
cL_user_md.data.format_kind = mkldnn_blocked; // overrides format cL_user_md.data.format_kind = dnnl_blocked; // overrides format
cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0]; cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0];
cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1]; cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1];
cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2]; cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2];
@ -262,92 +262,92 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md,
h_lstm_md, hL_lstm_md, cL_lstm_md); h_lstm_md, hL_lstm_md, cL_lstm_md);
mkldnn::stream stream(engine); dnnl::stream stream(engine);
// lstm primitive description // lstm primitive description
lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine); lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine);
// arguments (memory buffers) necessary for calculations // arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args; std::unordered_map<int, dnnl::memory> args;
// provide memory and check whether reorder is required // provide memory and check whether reorder is required
// x // x
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc(); const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc();
auto x_lstm_mem = xReorder ? mkldnn::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem; auto x_lstm_mem = xReorder ? dnnl::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
if (xReorder) if (xReorder)
reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem); reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem);
args[MKLDNN_ARG_SRC_LAYER] = x_lstm_mem; args[DNNL_ARG_SRC_LAYER] = x_lstm_mem;
// wx // wx
auto wx_user_mem = mkldnn::memory(wx_user_md, engine, Wx->getBuffer()); auto wx_user_mem = dnnl::memory(wx_user_md, engine, Wx->getBuffer());
const bool wxReorder = lstm_prim_desc.weights_layer_desc()!= wx_user_mem.get_desc(); const bool wxReorder = lstm_prim_desc.weights_layer_desc()!= wx_user_mem.get_desc();
auto wx_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem; auto wx_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
if (wxReorder) if (wxReorder)
reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem); reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem);
args[MKLDNN_ARG_WEIGHTS_LAYER] = wx_lstm_mem; args[DNNL_ARG_WEIGHTS_LAYER] = wx_lstm_mem;
// wr // wr
// Recurrent (iter) weights: wrap the caller's buffer, then stage it in the layout chosen by the primitive.
// The staging buffer is gated on wrReorder, i.e. the layout check made for THIS tensor one line above.
// It was previously gated on wxReorder (the input-weights check). When the two flags disagree, that either
// hands the primitive a freshly allocated buffer that is never filled, or makes the reorder below
// use the user buffer as both source and destination, so the expected layout is never produced.
auto wr_user_mem = mkldnn::memory(wr_user_md, engine, Wr->getBuffer()); auto wr_user_mem = dnnl::memory(wr_user_md, engine, Wr->getBuffer());
const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc(); const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc();
auto wr_lstm_mem = wrReorder ? mkldnn::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem; auto wr_lstm_mem = wrReorder ? dnnl::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
if (wrReorder) if (wrReorder)
reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem); reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem);
args[MKLDNN_ARG_WEIGHTS_ITER] = wr_lstm_mem; args[DNNL_ARG_WEIGHTS_ITER] = wr_lstm_mem;
// h // h
auto h_user_mem = mkldnn::memory(h_user_md, engine, h->getBuffer()); auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer());
const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
auto h_lstm_mem = hReorder ? mkldnn::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem; auto h_lstm_mem = hReorder ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
args[MKLDNN_ARG_DST_LAYER] = h_lstm_mem; args[DNNL_ARG_DST_LAYER] = h_lstm_mem;
// b // b
if(b) { if(b) {
auto b_user_mem = mkldnn::memory(b_user_md, engine, b->getBuffer()); auto b_user_mem = dnnl::memory(b_user_md, engine, b->getBuffer());
const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc(); const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc();
auto b_lstm_mem = bReorder ? mkldnn::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem; auto b_lstm_mem = bReorder ? dnnl::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
if (bReorder) if (bReorder)
reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem); reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem);
args[MKLDNN_ARG_BIAS] = b_lstm_mem; args[DNNL_ARG_BIAS] = b_lstm_mem;
} }
// hI // hI
if(hI) { if(hI) {
auto hI_user_mem = mkldnn::memory(hI_user_md, engine, hI->getBuffer()); auto hI_user_mem = dnnl::memory(hI_user_md, engine, hI->getBuffer());
const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc(); const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc();
auto hI_lstm_mem = hIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem; auto hI_lstm_mem = hIReorder ? dnnl::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
if (hIReorder) if (hIReorder)
reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem); reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem);
args[MKLDNN_ARG_SRC_ITER] = hI_lstm_mem; args[DNNL_ARG_SRC_ITER] = hI_lstm_mem;
} }
// cI // cI
if(cI) { if(cI) {
auto cI_user_mem = mkldnn::memory(cI_user_md, engine, cI->getBuffer()); auto cI_user_mem = dnnl::memory(cI_user_md, engine, cI->getBuffer());
const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc(); const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc();
auto cI_lstm_mem = cIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem; auto cI_lstm_mem = cIReorder ? dnnl::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
if (cIReorder) if (cIReorder)
reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem); reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem);
args[MKLDNN_ARG_SRC_ITER_C] = cI_lstm_mem; args[DNNL_ARG_SRC_ITER_C] = cI_lstm_mem;
} }
bool hLReorder(false), cLReorder(false); bool hLReorder(false), cLReorder(false);
mkldnn::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
// hL // hL
if(hL) { if(hL) {
hL_user_mem = mkldnn::memory(hL_user_md, engine, hL->getBuffer()); hL_user_mem = dnnl::memory(hL_user_md, engine, hL->getBuffer());
hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc(); hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc();
hL_lstm_mem = hLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem; hL_lstm_mem = hLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
args[MKLDNN_ARG_DST_ITER] = hL_lstm_mem; args[DNNL_ARG_DST_ITER] = hL_lstm_mem;
} }
// cL // cL
if(cL) { if(cL) {
cL_user_mem = mkldnn::memory(cL_user_md, engine, cL->getBuffer()); cL_user_mem = dnnl::memory(cL_user_md, engine, cL->getBuffer());
cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc(); cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc();
cL_lstm_mem = cLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem; cL_lstm_mem = cLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
args[MKLDNN_ARG_DST_ITER_C] = cL_lstm_mem; args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem;
} }
// run calculations // run calculations

View File

@ -27,7 +27,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -82,11 +82,11 @@ namespace nd4j {
auto poolingMode = PoolingType::MAX_POOL; auto poolingMode = PoolingType::MAX_POOL;
int extraParam0 = 1; int extraParam0 = 1;
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty); dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty); dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm; dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true, true,
@ -102,23 +102,23 @@ namespace nd4j {
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory; auto pool_src_memory = user_src_memory;
mkldnn::stream stream(engine); dnnl::stream stream(engine);
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
} }
auto pool_dst_memory = user_dst_memory; auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
} }
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}}); {DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);

View File

@ -27,7 +27,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -89,11 +89,11 @@ namespace nd4j {
auto poolingMode = PoolingType::MAX_POOL; auto poolingMode = PoolingType::MAX_POOL;
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm; dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true, true,
@ -109,44 +109,44 @@ namespace nd4j {
pool_padding_r); pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r); pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer()); auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer()); auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory; auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine); poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
} }
auto poolB_dst_memory = userB_dst_memory; auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine); poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
} }
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory; auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
} }
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine); auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}, {DNNL_ARG_DST, pool_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}}); {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
// probably wrong, fix that // probably wrong, fix that
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory}, pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}, {DNNL_ARG_WORKSPACE, pool_workspace_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}}); {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);

View File

@ -26,7 +26,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -87,11 +87,11 @@ namespace nd4j {
auto poolingMode = PoolingType::MAX_POOL; auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1; auto extraParam0 = 1;
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty); dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty); dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm; dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true, extraParam0, true,
@ -106,24 +106,24 @@ namespace nd4j {
pool_padding_r); pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory; auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
} }
auto pool_dst_memory = user_dst_memory; auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
} }
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}}); {DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);

View File

@ -26,7 +26,7 @@
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -93,11 +93,11 @@ namespace nd4j {
auto poolingMode = PoolingType::MAX_POOL; auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1; auto extraParam0 = 1;
mkldnn_memory_desc_t empty; dnnl_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm; dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true, extraParam0, true,
@ -115,44 +115,44 @@ namespace nd4j {
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine); dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer()); auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer()); auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory; auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine); poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
} }
auto poolB_dst_memory = userB_dst_memory; auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine); poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
} }
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory; auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
} }
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine); auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}, {DNNL_ARG_DST, pool_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}}); {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory}, pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}, {DNNL_ARG_WORKSPACE, pool_workspace_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}}); {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {

View File

@ -18,23 +18,23 @@
// @author saudet // @author saudet
// //
#include <mkldnn_types.h> #include <dnnl_types.h>
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h> #include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace dnnl;
namespace nd4j { namespace nd4j {
namespace mkldnnUtils { namespace mkldnnUtils {
void getMKLDNNMemoryDescPool2d( void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) { dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW }; dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW };
mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW }; dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW };
pool_strides = { sH, sW }; pool_strides = { sH, sW };
pool_kernel = { kH, kW }; pool_kernel = { kH, kW };
@ -45,14 +45,14 @@ namespace nd4j {
algorithm = poolingMode == 0 ? algorithm::pooling_max algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding; : algorithm::pooling_avg_include_padding;
auto type = mkldnn::memory::data_type::f32; auto type = dnnl::memory::data_type::f32;
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any" auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
@ -60,9 +60,9 @@ namespace nd4j {
} }
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
@ -70,9 +70,9 @@ namespace nd4j {
} }
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format); *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
@ -84,12 +84,12 @@ namespace nd4j {
void getMKLDNNMemoryDescPool3d( void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) { dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
pool_strides = { sD, sH, sW }; pool_strides = { sD, sH, sW };
pool_kernel = { kD, kH, kW }; pool_kernel = { kD, kH, kW };
@ -101,14 +101,14 @@ namespace nd4j {
algorithm = poolingMode == 0 ? algorithm::pooling_max algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding; : algorithm::pooling_avg_include_padding;
auto type = mkldnn::memory::data_type::f32; auto type = dnnl::memory::data_type::f32;
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc; auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nCdhw8c; // doesn't work with "any" auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
@ -117,9 +117,9 @@ namespace nd4j {
} }
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
@ -128,9 +128,9 @@ namespace nd4j {
} }
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format); *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
@ -145,15 +145,15 @@ namespace nd4j {
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) { dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW }; dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW }; dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC }; dnnl::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW }; dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW); int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl); nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
@ -169,14 +169,14 @@ namespace nd4j {
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW }; (oW - 1) * sW - iW + kW - pW };
auto type = mkldnn::memory::data_type::f32; auto type = dnnl::memory::data_type::f32;
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto formatw = mkldnn::memory::format_tag::hwio; auto formatw = dnnl::memory::format_tag::hwio;
if (src != nullptr && conv_src_md != nullptr) { if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any); *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format); *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
@ -184,9 +184,9 @@ namespace nd4j {
} }
if (diff_src != nullptr && conv_diff_src_md != nullptr) { if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any); *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format); *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
@ -194,9 +194,9 @@ namespace nd4j {
} }
if (weights != nullptr && conv_weights_md != nullptr) { if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any); *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw); *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio" user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3]; user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2]; user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
@ -204,9 +204,9 @@ namespace nd4j {
} }
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any); *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw); *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio" user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3]; user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2]; user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
@ -214,14 +214,14 @@ namespace nd4j {
} }
if (bias != nullptr && conv_bias_md != nullptr) { if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any); *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x); *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
} }
if (dst != nullptr && conv_dst_md != nullptr) { if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any); *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format); *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
@ -233,15 +233,15 @@ namespace nd4j {
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) { dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW }; dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC }; dnnl::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW); int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl); nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
@ -251,14 +251,14 @@ namespace nd4j {
conv_padding_r = { pDmkl, pHmkl, pWmkl }; conv_padding_r = { pDmkl, pHmkl, pWmkl };
conv_dilation = { dDmkl, dHmkl, dWmkl }; conv_dilation = { dDmkl, dHmkl, dWmkl };
auto type = mkldnn::memory::data_type::f32; auto type = dnnl::memory::data_type::f32;
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc; auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
auto formatw = mkldnn::memory::format_tag::dhwio; auto formatw = dnnl::memory::format_tag::dhwio;
if (src != nullptr && conv_src_md != nullptr) { if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any); *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format); *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
@ -267,9 +267,9 @@ namespace nd4j {
} }
if (diff_src != nullptr && conv_diff_src_md != nullptr) { if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any); *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format); *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
@ -278,9 +278,9 @@ namespace nd4j {
} }
if (weights != nullptr && conv_weights_md != nullptr) { if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any); *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw); *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio" user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4]; user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3]; user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
@ -289,9 +289,9 @@ namespace nd4j {
} }
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any); *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw); *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio" user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4]; user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3]; user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
@ -300,14 +300,14 @@ namespace nd4j {
} }
if (bias != nullptr && conv_bias_md != nullptr) { if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any); *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x); *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
} }
if (dst != nullptr && conv_dst_md != nullptr) { if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any); *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format); *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
@ -318,23 +318,23 @@ namespace nd4j {
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, // void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
// mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md, // dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
// mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) { // dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
// const Nd4jLong* shape = src->getShapeInfo(); // const Nd4jLong* shape = src->getShapeInfo();
// Nd4jLong rank = shape[0]; // Nd4jLong rank = shape[0];
// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one // Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
// Nd4jLong dim2 = axis >= 2 ? 1 : 2; // Nd4jLong dim2 = axis >= 2 ? 1 : 2;
// Nd4jLong dim3 = axis >= 3 ? 2 : 3; // Nd4jLong dim3 = axis >= 3 ? 2 : 3;
// mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; // dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
// auto type = mkldnn::memory::data_type::f32; // auto type = dnnl::memory::data_type::f32;
// auto format = mkldnn::memory::format_tag::nchw; // auto format = dnnl::memory::format_tag::nchw;
// auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any" // auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { // if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
// *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); // *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); // *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_src_md->data.format_kind = mkldnn_blocked; // overrides format // user_src_md->data.format_kind = dnnl_blocked; // overrides format
// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; // user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; // user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; // user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
@ -342,9 +342,9 @@ namespace nd4j {
// } // }
// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { // if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
// *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); // *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); // *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format // user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; // user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; // user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; // user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
@ -352,9 +352,9 @@ namespace nd4j {
// } // }
// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { // if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
// *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); // *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); // *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_dst_md->data.format_kind = mkldnn_blocked; // overrides format // user_dst_md->data.format_kind = dnnl_blocked; // overrides format
// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; // user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; // user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; // user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
@ -364,23 +364,23 @@ namespace nd4j {
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md, dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) { dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo(); const Nd4jLong* shape = src->getShapeInfo();
long rank = shape[0]; long rank = shape[0];
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
long dim2 = axis >= 2 ? 1 : 2; long dim2 = axis >= 2 ? 1 : 2;
long dim3 = axis >= 3 ? 2 : 3; long dim3 = axis >= 3 ? 2 : 3;
mkldnn::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = mkldnn::memory::data_type::f32; auto type = dnnl::memory::data_type::f32;
auto format = axis == 1 ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto supposed_to_be_any_format = format; // doesn't work with "any" auto supposed_to_be_any_format = format; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) { if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
*lrn_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format); *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; user_src_md->data.format_kind = dnnl_blocked;
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
@ -388,9 +388,9 @@ namespace nd4j {
} }
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) { if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
*lrn_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format); *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; user_diff_src_md->data.format_kind = dnnl_blocked;
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
@ -398,9 +398,9 @@ namespace nd4j {
} }
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) { if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
*lrn_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, format); *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; user_dst_md->data.format_kind = dnnl_blocked;
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
@ -408,8 +408,8 @@ namespace nd4j {
} }
} }
mkldnn::engine& getEngine(void *ptr) { dnnl::engine& getEngine(void *ptr) {
auto eng = reinterpret_cast<mkldnn::engine*>(ptr); auto eng = reinterpret_cast<dnnl::engine*>(ptr);
return *eng; return *eng;
} }
} }

View File

@ -23,7 +23,7 @@
#include <NativeOps.h> #include <NativeOps.h>
#include <NDArray.h> #include <NDArray.h>
#include <mkldnn.hpp> #include <dnnl.hpp>
#include <MKLDNNStream.h> #include <MKLDNNStream.h>
#include <graph/Context.h> #include <graph/Context.h>
#include <ops/declarable/PlatformHelper.h> #include <ops/declarable/PlatformHelper.h>
@ -89,47 +89,47 @@ namespace nd4j{
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation); dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
void getMKLDNNMemoryDescConv3d( void getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation); dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
void getMKLDNNMemoryDescPool2d( void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r); dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
void getMKLDNNMemoryDescPool3d( void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r); dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md, dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis); dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md, dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis); dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
mkldnn::engine& getEngine(void *ptr); dnnl::engine& getEngine(void *ptr);
} }
} }

View File

@ -42,7 +42,7 @@ if ("${BUILD_MKLDNN}")
set(HAVE_MKLDNN 1) set(HAVE_MKLDNN 1)
add_definitions("-DHAVE_MKLDNN") add_definitions("-DHAVE_MKLDNN")
include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_SOURCE_DIR}/external/mklml_lnx_2019.0.3.20190220/include ${mkldnn_SOURCE_DIR}) include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_SOURCE_DIR}/external/mklml_lnx_2019.0.3.20190220/include ${mkldnn_SOURCE_DIR})
set(MKLDNN mkldnn) set(MKLDNN dnnl)
endif() endif()
# Download and unpack flatbuffers at configure time # Download and unpack flatbuffers at configure time