- MKL-DNN version upgrade to 1.1.x (#62)

- MKL-DNN namespace changes to match DNNL rename

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-20 13:23:08 +03:00 committed by GitHub
parent 7898f3c0cc
commit 59e955cedc
22 changed files with 836 additions and 836 deletions
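
The change itself is mechanical: the mkldnn C++ namespace becomes dnnl, C-level types such as mkldnn_memory_desc_t become dnnl_memory_desc_t, the MKLDNN_ARG_* execution-argument macros become DNNL_ARG_*, and the header moves from <mkldnn.hpp> to <dnnl.hpp>. A minimal before/after sketch of the affected spellings (the engine/stream setup is illustrative only, not taken from this diff):

// Before, against MKL-DNN 1.0.x:
//   #include <mkldnn.hpp>
//   mkldnn::engine engine(mkldnn::engine::kind::cpu, 0);
//   mkldnn::stream stream(engine);
//   primitive.execute(stream, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}});

// After, against DNNL 1.1.x:
#include <dnnl.hpp>
dnnl::engine engine(dnnl::engine::kind::cpu, 0);
dnnl::stream stream(engine);
// primitive.execute(stream, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});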


@@ -51,7 +51,7 @@ endif()
 if(NOT CUDA_BLAS)
 # we need this definition to avoid global memory use within mkldnn
-add_definitions(-DMKLDNN_ENABLE_CONCURRENT_EXEC=true)
+add_definitions(-DDNNL_ENABLE_CONCURRENT_EXEC=true)
 # there's a chance, we have no BLAS provided externally
 if ("${OPENBLAS_PATH}" STREQUAL "")
@@ -122,7 +122,7 @@ if(NOT CUDA_BLAS)
 if (${HELPERS_mkldnn})
 message("Going to pull & build mkldnn")
 set(HAVE_MKLDNN 1)
-set(MKLDNN_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE)
+set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE)
 configure_file(./CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt)
 execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
@@ -146,7 +146,7 @@ if(NOT CUDA_BLAS)
 set(mkldnn_EXT_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src)
 set(MKLDNN_PATH "${mkldnn_SOURCE_DIR}")
 include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_EXT_DIR}/include ${mkldnn_SOURCE_DIR})
-set(MKLDNN mkldnn)
+set(MKLDNN dnnl)
 endif()
 endif()


@@ -5,11 +5,11 @@ project(mkldnn-download NONE)
 include(ExternalProject)
 ExternalProject_Add(mkldnn
 GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
-GIT_TAG v1.0.4
+GIT_TAG v1.1.1
 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
 BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
 CONFIGURE_COMMAND ""
-CMAKE_ARGS -DMKLDNN_USE_MKL=ML -DMKLDNN_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\"
+CMAKE_ARGS -DDNNL_USE_MKL=ML -DDNNL_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\"
 BUILD_COMMAND ""
 INSTALL_COMMAND ""
 TEST_COMMAND ""


@@ -30,14 +30,14 @@ thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
 #endif
 #ifdef HAVE_MKLDNN
-#include <mkldnn.hpp>
+#include <dnnl.hpp>
 #endif
 namespace nd4j {
 LaunchContext::~LaunchContext() {
 #ifdef HAVE_MKLDNN
-delete reinterpret_cast<mkldnn::engine*>(_engine);
+delete reinterpret_cast<dnnl::engine*>(_engine);
 #endif
 }
@@ -50,7 +50,7 @@ namespace nd4j {
 _deviceID = 0;
 #ifdef HAVE_MKLDNN
-_engine = new mkldnn::engine(mkldnn::engine::kind::cpu, 0);
+_engine = new dnnl::engine(dnnl::engine::kind::cpu, 0);
 #endif
 }
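
A note on the LaunchContext hunks above: the context keeps the engine behind an opaque void* member (_engine), so dnnl.hpp never has to appear in the public header and only this translation unit performs the cast. A minimal sketch of that pattern (hypothetical EngineHolder class, assuming DNNL 1.1.x; not part of this diff):

#include <dnnl.hpp>

// Hypothetical holder mirroring the LaunchContext pattern above: the header
// would only declare "void* _engine;", keeping dnnl.hpp an implementation detail.
class EngineHolder {
public:
    EngineHolder()  { _engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); }
    ~EngineHolder() { delete reinterpret_cast<dnnl::engine*>(_engine); }
private:
    void* _engine; // really a dnnl::engine*
};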


@@ -27,7 +27,7 @@
 #include "mkldnnUtils.h"
 #include <ops/declarable/helpers/convolutions.h>
-using namespace mkldnn;
+using namespace dnnl;
 namespace nd4j {
 namespace ops {
@@ -82,11 +82,11 @@ namespace nd4j {
 auto poolingMode = PoolingType::AVG_POOL;
-mkldnn_memory_desc_t empty;
-mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
-mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
-mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-mkldnn::algorithm algorithm;
+dnnl_memory_desc_t empty;
+dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
+dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
+dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
+dnnl::algorithm algorithm;
 mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
 true,
 bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
@@ -99,20 +99,20 @@ namespace nd4j {
 pool_strides, pool_kernel, pool_padding, pool_padding_r);
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
 auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
-auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
+auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
+auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
 auto pool_src_memory = user_src_memory;
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
-pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
+pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
 reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
 }
 auto pool_dst_memory = user_dst_memory;
 if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
+pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
 }
-pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
-{MKLDNN_ARG_DST, pool_dst_memory}});
+pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
+{DNNL_ARG_DST, pool_dst_memory}});
 if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
 reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
 }


@@ -27,7 +27,7 @@
 #include "mkldnnUtils.h"
 #include <ops/declarable/helpers/convolutions.h>
-using namespace mkldnn;
+using namespace dnnl;
 namespace nd4j {
 namespace ops {
@@ -89,11 +89,11 @@ namespace nd4j {
 auto poolingMode = PoolingType::AVG_POOL;
-mkldnn_memory_desc_t empty;
-mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
-mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
-mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-mkldnn::algorithm algorithm;
+dnnl_memory_desc_t empty;
+dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
+dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
+dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
+dnnl::algorithm algorithm;
 mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
 true,
 bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
@@ -109,20 +109,20 @@ namespace nd4j {
 auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
 pool_kernel, pool_padding, pool_padding_r);
 auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
-auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer());
-auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
+auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
+auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
 auto poolB_src_memory = userB_src_memory;
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
+poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
 }
 auto poolB_dst_memory = userB_dst_memory;
 if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
-poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
+poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
 reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
 }
-pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
-{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
+pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
+{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
 if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
 reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
 }


@@ -27,7 +27,7 @@
 #include "mkldnnUtils.h"
 #include <ops/declarable/helpers/convolutions.h>
-using namespace mkldnn;
+using namespace dnnl;
 namespace nd4j {
 namespace ops {
@@ -86,11 +86,11 @@ namespace nd4j {
 auto poolingMode = PoolingType::AVG_POOL;
-mkldnn_memory_desc_t empty;
-mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
-mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
-mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-mkldnn::algorithm algorithm;
+dnnl_memory_desc_t empty;
+dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
+dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
+dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
+dnnl::algorithm algorithm;
 mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
 extraParam0, true,
 bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
@@ -102,21 +102,21 @@ namespace nd4j {
 pool_dst_md,
 pool_strides, pool_kernel, pool_padding, pool_padding_r);
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
-auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
+auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
+auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
 auto pool_src_memory = user_src_memory;
 if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
-pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
+pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
 reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
 }
 auto pool_dst_memory = user_dst_memory;
 if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
+pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
 }
-pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
-{MKLDNN_ARG_DST, pool_dst_memory}});
+pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
+{DNNL_ARG_DST, pool_dst_memory}});
 if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
 reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
 }


@@ -26,7 +26,7 @@
 #include "mkldnnUtils.h"
 #include <ops/declarable/helpers/convolutions.h>
-using namespace mkldnn;
+using namespace dnnl;
 namespace nd4j {
 namespace ops {
@@ -92,11 +92,11 @@ namespace nd4j {
 auto poolingMode = PoolingType::AVG_POOL;
-mkldnn_memory_desc_t empty;
-mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
-mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
-mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-mkldnn::algorithm algorithm;
+dnnl_memory_desc_t empty;
+dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
+dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
+dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
+dnnl::algorithm algorithm;
 mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
 extraParam0, true,
 bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
@@ -111,24 +111,24 @@ namespace nd4j {
 auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md,
 pool_strides, pool_kernel, pool_padding, pool_padding_r);
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
 auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
 pool_kernel, pool_padding, pool_padding_r);
 auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
-auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
-auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
+auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
+auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
 auto poolB_src_memory = userB_src_memory;
 if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
+poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
 }
 auto poolB_dst_memory = userB_dst_memory;
 if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
-poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
+poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
 reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
 }
-pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
-{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
+pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
+{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
 if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
 reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
 }


@@ -39,7 +39,7 @@ namespace platforms {
 //////////////////////////////////////////////////////////////////////////
 static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) {
-// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
+// unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
 // also it gives wrong results for formats nhwc and ndhwc
 // x -> 2D:nc, 4D:nchw, 5D:ncdhw
@@ -53,35 +53,35 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
 // input type
-mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
+dnnl::memory::data_type type = dnnl::memory::data_type::f32;
 // indicate whether gamma or/and beta are given
-auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
+auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
 if (weights != nullptr)
-flags |= mkldnn::normalization_flags::use_scale_shift;
+flags |= dnnl::normalization_flags::use_scale_shift;
-mkldnn::memory::dims dims;
-mkldnn::memory::format_tag format;
+dnnl::memory::dims dims;
+dnnl::memory::format_tag format;
 if(xRank == 2) {
 dims = {x->sizeAt(0), x->sizeAt(1)};
-format = mkldnn::memory::format_tag::nc;
+format = dnnl::memory::format_tag::nc;
 }
 else if(xRank == 4) {
 dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
-format = mkldnn::memory::format_tag::nchw;
+format = dnnl::memory::format_tag::nchw;
 }
 else { // xRank = 5
 dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
-format = mkldnn::memory::format_tag::ncdhw;
+format = dnnl::memory::format_tag::ncdhw;
 }
 // memory descriptors for arrays
 // x
-mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format);
-mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format);
-x_user_md.data.format_kind = mkldnn_blocked; // overrides format
+dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
+dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
+x_user_md.data.format_kind = dnnl_blocked; // overrides format
 x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
 x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
 if(xRank > 2) {
@@ -92,9 +92,9 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
 x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
 // z, output
-mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(dims, type, format);
-mkldnn::memory::desc z_user_md = mkldnn::memory::desc(dims, type, format);
-z_user_md.data.format_kind = mkldnn_blocked; // overrides format
+dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, format);
+dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
+z_user_md.data.format_kind = dnnl_blocked; // overrides format
 z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0];
 z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1];
 if(xRank > 2) {
@@ -106,53 +106,53 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
 // batchnorm forward description
-mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
-mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
+dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
+dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
 // arguments (memory buffers) necessary for calculations
-std::unordered_map<int, mkldnn::memory> args;
+std::unordered_map<int, dnnl::memory> args;
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 // provide memory and check whether reorder is required
 // x
-auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
+auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
 const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
-auto x_mkl_mem = xReorder ? mkldnn::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
+auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
 if (xReorder)
-mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
-args[MKLDNN_ARG_SRC] = x_mkl_mem;
+dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+args[DNNL_ARG_SRC] = x_mkl_mem;
 // z
-auto z_user_mem = mkldnn::memory(z_user_md, engine, z->getBuffer());
+auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
 const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
-auto z_mkl_mem = zReorder ? mkldnn::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
+auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
 if (zReorder)
-mkldnn::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
-args[MKLDNN_ARG_DST] = z_mkl_mem;
+dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
+args[DNNL_ARG_DST] = z_mkl_mem;
 // mean
-auto mean_mkl_mem = mkldnn::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer());
-args[MKLDNN_ARG_MEAN] = mean_mkl_mem;
+auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer());
+args[DNNL_ARG_MEAN] = mean_mkl_mem;
 // variance
-auto var_mkl_mem = mkldnn::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer());
-args[MKLDNN_ARG_VARIANCE] = var_mkl_mem;
+auto var_mkl_mem = dnnl::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer());
+args[DNNL_ARG_VARIANCE] = var_mkl_mem;
 // gamma and beta (and their gradients) if they are present
 if(weights != nullptr) {
-auto w_mkl_mem = mkldnn::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer());
-args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
+auto w_mkl_mem = dnnl::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer());
+args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
 }
 // run calculations
-mkldnn::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);
+dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);
 // reorder outputs if necessary
 if (zReorder)
-mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
 stream.wait();
@@ -164,7 +164,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
 static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
 const float epsilon, NDArray* dLdI, NDArray* dLdW) {
-// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
+// unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
 // also it gives wrong results for formats nhwc and ndhwc
 // x -> 2D:nc, 4D:nchw, 5D:ncdhw
@@ -180,35 +180,35 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
 // input type
-mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
+dnnl::memory::data_type type = dnnl::memory::data_type::f32;
 // indicate whether gamma or/and beta are given
-auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
+auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
 if (weights != nullptr)
-flags |= mkldnn::normalization_flags::use_scale_shift;
+flags |= dnnl::normalization_flags::use_scale_shift;
-mkldnn::memory::dims dims;
-mkldnn::memory::format_tag format;
+dnnl::memory::dims dims;
+dnnl::memory::format_tag format;
 if(xRank == 2) {
 dims = {x->sizeAt(0), x->sizeAt(1)};
-format = mkldnn::memory::format_tag::nc;
+format = dnnl::memory::format_tag::nc;
 }
 else if(xRank == 4) {
 dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
-format = mkldnn::memory::format_tag::nchw;
+format = dnnl::memory::format_tag::nchw;
 }
 else { // xRank = 5
 dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
-format = mkldnn::memory::format_tag::ncdhw;
+format = dnnl::memory::format_tag::ncdhw;
 }
 // memory descriptors for arrays
 // x
-mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format);
-mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format);
-x_user_md.data.format_kind = mkldnn_blocked; // overrides format
+dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
+dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
+x_user_md.data.format_kind = dnnl_blocked; // overrides format
 x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
 x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
 if(xRank > 2) {
@@ -219,9 +219,9 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
 x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
 // dLdO
-mkldnn::memory::desc dLdO_mkl_md = mkldnn::memory::desc(dims, type, format);
-mkldnn::memory::desc dLdO_user_md = mkldnn::memory::desc(dims, type, format);
-dLdO_user_md.data.format_kind = mkldnn_blocked; // overrides format
+dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, format);
+dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
+dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format
 dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0];
 dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1];
 if(xRank > 2) {
@@ -232,9 +232,9 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
 dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4];
 // dLdI
-mkldnn::memory::desc dLdI_mkl_md = mkldnn::memory::desc(dims, type, format);
-mkldnn::memory::desc dLdI_user_md = mkldnn::memory::desc(dims, type, format);
-dLdI_user_md.data.format_kind = mkldnn_blocked; // overrides format
+dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, format);
+dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
+dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format
 dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0];
 dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1];
 if(xRank > 2) {
@@ -245,66 +245,66 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
 dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4];
 // batchnorm forward description
-mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
-mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
+dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
+dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
 // batchnorm backprop description
-mkldnn::batch_normalization_backward::desc op_bp_desc(mkldnn::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags);
-mkldnn::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
+dnnl::batch_normalization_backward::desc op_bp_desc(dnnl::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags);
+dnnl::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
 // arguments (memory buffers) necessary for calculations
-std::unordered_map<int, mkldnn::memory> args;
+std::unordered_map<int, dnnl::memory> args;
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 // provide memory and check whether reorder is required
 // x
-auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
+auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
 const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc();
-auto x_mkl_mem = xReorder ? mkldnn::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
+auto x_mkl_mem = xReorder ? dnnl::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
 if (xReorder)
-mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
-args[MKLDNN_ARG_SRC] = x_mkl_mem;
+dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+args[DNNL_ARG_SRC] = x_mkl_mem;
 // dLdO
-auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer());
+auto dLdO_user_mem = dnnl::memory(dLdO_user_md, engine, dLdO->getBuffer());
 const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc();
-auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
+auto dLdO_mkl_mem = dLdOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
 if (dLdOReorder)
-mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
-args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem;
+dnnl::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
+args[DNNL_ARG_DIFF_DST] = dLdO_mkl_mem;
 // mean
-auto mean_mkl_mem = mkldnn::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
-args[MKLDNN_ARG_MEAN] = mean_mkl_mem;
+auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
+args[DNNL_ARG_MEAN] = mean_mkl_mem;
 // variance
-auto var_mkl_mem = mkldnn::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer());
-args[MKLDNN_ARG_VARIANCE] = var_mkl_mem;
+auto var_mkl_mem = dnnl::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer());
+args[DNNL_ARG_VARIANCE] = var_mkl_mem;
 // dLdI
-auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer());
+auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->getBuffer());
 const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc();
-auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
-args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem;
+auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
+args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem;
 // gamma and beta (and their gradients) if they are present
 if(weights != nullptr) {
-auto w_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer());
-args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
+auto w_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer());
+args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
-auto dLdW_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer());
-args[MKLDNN_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem;
+auto dLdW_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer());
+args[DNNL_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem;
 }
 // run calculations
-mkldnn::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);
+dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);
 // reorder outputs if necessary
 if (dLdIReorder)
-mkldnn::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
+dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
 stream.wait();
@@ -532,37 +532,37 @@ PLATFORM_CHECK(batchnorm) {
 // weights({1, 2, 0, 0}).assign(0.0f);
 // mkldnn_memory_desc_t empty;
-// mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
+// dnnl::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
-// auto flag = mkldnn::normalization_flags::use_global_stats;
+// auto flag = dnnl::normalization_flags::use_global_stats;
 // if (applyScale || applyOffset)
-// flag |= mkldnn::normalization_flags::use_scale_shift;
+// flag |= dnnl::normalization_flags::use_scale_shift;
 // mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
 // &batchnorm_src_md, nullptr, &batchnorm_dst_md,
 // &user_src_md, nullptr, &user_dst_md, axes[0]);
-// auto batchnorm_desc = mkldnn::batch_normalization_forward::desc(mkldnn::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag);
+// auto batchnorm_desc = dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag);
 // auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-// mkldnn::stream stream(engine);
-// auto batchnorm_prim_desc = mkldnn::batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
-// auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
-// auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
-// auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine,
+// dnnl::stream stream(engine);
+// auto batchnorm_prim_desc = dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
+// auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
+// auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
+// auto batchnorm_mean_memory = dnnl::memory(batchnorm_prim_desc.mean_desc(), engine,
 // mean->buffer());
-// auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine,
+// auto batchnorm_variance_memory = dnnl::memory(batchnorm_prim_desc.variance_desc(), engine,
 // variance->buffer());
 // auto batchnorm_src_memory = user_src_memory;
-// mkldnn::memory m(batchnorm_src_md, engine);
+// dnnl::memory m(batchnorm_src_md, engine);
 // if (m.get_desc() != user_src_memory.get_desc()) {
-// batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine);
-// mkldnn::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
+// batchnorm_src_memory = dnnl::memory(batchnorm_src_md, engine);
+// dnnl::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
 // batchnorm_src_memory);
 // }
 // auto batchnorm_dst_memory = user_dst_memory;
 // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-// batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine);
+// batchnorm_dst_memory = dnnl::memory(batchnorm_prim_desc.dst_desc(), engine);
 // }
 // if (applyScale || applyOffset) {
 // if (gamma != nullptr) {
@@ -572,22 +572,22 @@ PLATFORM_CHECK(batchnorm) {
 // weights({1, 2, 0, 0}).assign(beta);
 // }
-// auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
-// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
+// auto batchnorm_weights_memory = dnnl::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
+// dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
 // {{MKLDNN_ARG_SRC, batchnorm_src_memory},
 // {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
 // {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
 // {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory},
 // {MKLDNN_ARG_DST, batchnorm_dst_memory}});
 // } else {
-// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
+// dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
 // {{MKLDNN_ARG_SRC, batchnorm_src_memory},
 // {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
 // {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
 // {MKLDNN_ARG_DST, batchnorm_dst_memory}});
 // }
 // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-// mkldnn::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
+// dnnl::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
 // user_dst_memory);
 // }
 // stream.wait();


@@ -27,7 +27,7 @@
 #include "mkldnnUtils.h"
 #include <ops/declarable/helpers/convolutions.h>
-using namespace mkldnn;
+using namespace dnnl;
 namespace nd4j {
 namespace ops {
@@ -47,12 +47,12 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
 if(isSameMode) // SAME
 ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
-mkldnn_memory_desc_t empty;
-mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
+dnnl_memory_desc_t empty;
+dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
 empty);
-mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
+dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
 empty);
-mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
+dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
 mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
 bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
 bias, output,
@@ -74,38 +74,38 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
 conv_dst_md, conv_strides, conv_dilation, conv_padding,
 conv_padding_r);
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
-auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
-auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
+auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
+auto user_weights_memory = dnnl::memory(user_weights_md, engine,
 const_cast<NDArray *>(weights)->buffer());
-auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
+auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
 auto conv_src_memory = user_src_memory;
 if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
-conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
+conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
 reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
 }
 auto conv_weights_memory = user_weights_memory;
 if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
-conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
+conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine);
 reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
 conv_weights_memory);
 }
 auto conv_dst_memory = user_dst_memory;
 if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
+conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine);
 }
 if (bias != nullptr) {
-auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine,
+auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine,
 const_cast<NDArray *>(bias)->buffer());
-convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
-{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
-{MKLDNN_ARG_BIAS, conv_bias_memory},
-{MKLDNN_ARG_DST, conv_dst_memory}});
+convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
+{DNNL_ARG_WEIGHTS, conv_weights_memory},
+{DNNL_ARG_BIAS, conv_bias_memory},
+{DNNL_ARG_DST, conv_dst_memory}});
 } else {
-convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
-{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
-{MKLDNN_ARG_DST, conv_dst_memory}});
+convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
+{DNNL_ARG_WEIGHTS, conv_weights_memory},
+{DNNL_ARG_DST, conv_dst_memory}});
 }
 if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
 reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
@@ -198,12 +198,12 @@ PLATFORM_IMPL(conv2d_bp) {
 if (isSameMode) // SAME
 ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
-mkldnn_memory_desc_t empty;
-mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
+dnnl_memory_desc_t empty;
+dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
 conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
-mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
+dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
 user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
-mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
+dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
 mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
 bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
 gradB, gradO,
@@ -235,47 +235,47 @@ PLATFORM_IMPL(conv2d_bp) {
 conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
 conv_prim_desc);
-auto userW_src_memory = mkldnn::memory(user_src_md, engine,
+auto userW_src_memory = dnnl::memory(user_src_md, engine,
 const_cast<NDArray *>(input)->buffer());
-auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
-auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
+auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
+auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
 const_cast<NDArray *>(gradO)->buffer());
 auto convW_src_memory = userW_src_memory;
 if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
-convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
+convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
 reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
 convW_src_memory);
 }
 auto convW_weights_memory = userW_weights_memory;
 if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
-convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
+convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
 }
 auto convW_dst_memory = userW_dst_memory;
 if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
-convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
+convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
 reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
 convW_dst_memory);
 }
 if (gradB != nullptr) {
-auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
+auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
 gradB->buffer());
 convolution_backward_weights(convW_prim_desc).execute(stream,
-{{MKLDNN_ARG_SRC, convW_src_memory},
-{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
-{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
-{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
+{{DNNL_ARG_SRC, convW_src_memory},
+{DNNL_ARG_DIFF_DST, convW_dst_memory},
+{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
+{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
 } else {
 convolution_backward_weights(convW_prim_desc).execute(stream,
-{{MKLDNN_ARG_SRC, convW_src_memory},
-{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
-{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
+{{DNNL_ARG_SRC, convW_src_memory},
+{DNNL_ARG_DIFF_DST, convW_dst_memory},
+{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
 }
 if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
@@ -293,38 +293,38 @@ PLATFORM_IMPL(conv2d_bp) {
 conv_padding, conv_padding_r);
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
 conv_prim_desc);
-auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
-auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
+auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
+auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
 const_cast<NDArray *>(weights)->buffer());
-auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
+auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
 const_cast<NDArray *>(gradO)->buffer());
 auto convI_src_memory = userI_src_memory;
 if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
-convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
+convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
 }
 auto convI_weights_memory = userI_weights_memory;
 if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
-convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
+convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
 reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
 convI_weights_memory);
 }
 auto convI_dst_memory = userI_dst_memory;
 if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
-convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
+convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
 reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
 convI_dst_memory);
 }
 convolution_backward_data(convI_prim_desc).execute(stream,
-{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
-{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
-{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
+{{DNNL_ARG_DIFF_DST, convI_dst_memory},
+{DNNL_ARG_WEIGHTS, convI_weights_memory},
+{DNNL_ARG_DIFF_SRC, convI_src_memory}});
 if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
 reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,


@@ -27,7 +27,7 @@
 #include "mkldnnUtils.h"
 #include <ops/declarable/helpers/convolutions.h>
-using namespace mkldnn;
+using namespace dnnl;
 namespace nd4j {
 namespace ops {
@@ -83,12 +83,12 @@ PLATFORM_IMPL(conv3dnew) {
 ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
-mkldnn_memory_desc_t empty;
-mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
+dnnl_memory_desc_t empty;
+dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
 empty);
-mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
+dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
 empty);
-mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
+dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
 mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
 isNCDHW,
 bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
@@ -110,37 +110,37 @@ PLATFORM_IMPL(conv3dnew) {
 conv_dst_md, conv_strides, conv_dilation, conv_padding,
 conv_padding_r);
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
-auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
-auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
+auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
+auto user_weights_memory = dnnl::memory(user_weights_md, engine,
 const_cast<NDArray *>(weights)->buffer());
-auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
+auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
 auto conv_src_memory = user_src_memory;
 if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
-conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
+conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
 reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
 }
 auto conv_weights_memory = user_weights_memory;
 if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
-conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
+conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine);
 reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
 conv_weights_memory);
 }
 auto conv_dst_memory = user_dst_memory;
 if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
+conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine);
 }
 if (bias != nullptr) {
-auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
-convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
-{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
-{MKLDNN_ARG_BIAS, conv_bias_memory},
-{MKLDNN_ARG_DST, conv_dst_memory}});
+auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
+convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
+{DNNL_ARG_WEIGHTS, conv_weights_memory},
+{DNNL_ARG_BIAS, conv_bias_memory},
+{DNNL_ARG_DST, conv_dst_memory}});
 } else {
-convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
-{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
-{MKLDNN_ARG_DST, conv_dst_memory}});
+convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
+{DNNL_ARG_WEIGHTS, conv_weights_memory},
+{DNNL_ARG_DST, conv_dst_memory}});
 }
 if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
 reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
@@ -235,12 +235,12 @@ PLATFORM_IMPL(conv3dnew_bp) {
 oC, bias->rankOf(), bias->lengthOf());
-mkldnn_memory_desc_t empty;
-mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
+dnnl_memory_desc_t empty;
+dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
 conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
-mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
+dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
 user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
-mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
+dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
 mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
 isNDHWC,
 bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
@@ -273,47 +273,47 @@ PLATFORM_IMPL(conv3dnew_bp) {
 conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
 conv_prim_desc);
-auto userW_src_memory = mkldnn::memory(user_src_md, engine,
+auto userW_src_memory = dnnl::memory(user_src_md, engine,
 const_cast<NDArray *>(input)->buffer());
-auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
-auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
+auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
+auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
 const_cast<NDArray *>(gradO)->buffer());
 auto convW_src_memory = userW_src_memory;
 if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
-convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
+convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
 reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
 convW_src_memory);
 }
 auto convW_weights_memory = userW_weights_memory;
 if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
-convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
+convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
 }
 auto convW_dst_memory = userW_dst_memory;
 if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
-convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
+convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
 reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
 convW_dst_memory);
 }
 if (gradB != nullptr) {
-auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
+auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
 gradB->buffer());
 convolution_backward_weights(convW_prim_desc).execute(stream,
-{{MKLDNN_ARG_SRC, convW_src_memory},
-{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
-{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
-{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
+{{DNNL_ARG_SRC, convW_src_memory},
+{DNNL_ARG_DIFF_DST, convW_dst_memory},
+{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
+{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
 } else {
 convolution_backward_weights(convW_prim_desc).execute(stream,
-{{MKLDNN_ARG_SRC, convW_src_memory},
-{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
-{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
+{{DNNL_ARG_SRC, convW_src_memory},
+{DNNL_ARG_DIFF_DST, convW_dst_memory},
+{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
 }
 if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
@@ -330,38 +330,38 @@ PLATFORM_IMPL(conv3dnew_bp) {
 conv_padding_r);
 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-mkldnn::stream stream(engine);
+dnnl::stream stream(engine);
 auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
 conv_prim_desc);
-auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
-auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
+auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
+auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
 const_cast<NDArray *>(weights)->buffer());
-auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
+auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
 const_cast<NDArray *>(gradO)->buffer());
 auto convI_src_memory = userI_src_memory;
 if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
-convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
+convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
 }
 auto convI_weights_memory = userI_weights_memory;
 if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
-convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
+convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
 reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
 convI_weights_memory);
 }
 auto convI_dst_memory = userI_dst_memory;
 if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
-convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
+convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
 reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
 convI_dst_memory);
 }
 convolution_backward_data(convI_prim_desc).execute(stream,
-{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
-{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
-{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
+{{DNNL_ARG_DIFF_DST, convI_dst_memory},
+{DNNL_ARG_WEIGHTS, convI_weights_memory},
+{DNNL_ARG_DIFF_SRC, convI_src_memory}});
 if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
 reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,


@@ -49,77 +49,77 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
 int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
 ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
-mkldnn::memory::dims strides = { sH, sW };
-mkldnn::memory::dims padding = { pH, pW };
-mkldnn::memory::dims padding_r = { pHmkl, pWmkl };
-mkldnn::memory::dims dilation = { dHmkl, dWmkl };
+dnnl::memory::dims strides = { sH, sW };
+dnnl::memory::dims padding = { pH, pW };
+dnnl::memory::dims padding_r = { pHmkl, pWmkl };
+dnnl::memory::dims dilation = { dHmkl, dWmkl };
 // input type
-mkldnn::memory::data_type xType;
+dnnl::memory::data_type xType;
 if(input->dataType() == DataType::FLOAT32)
-xType = mkldnn::memory::data_type::f32;
+xType = dnnl::memory::data_type::f32;
 else if(input->dataType() == DataType::HALF)
-xType = mkldnn::memory::data_type::f16;
+xType = dnnl::memory::data_type::f16;
 else if(input->dataType() == DataType::UINT8)
-xType = mkldnn::memory::data_type::u8;
+xType = dnnl::memory::data_type::u8;
 else
-xType = mkldnn::memory::data_type::s8;
+xType = dnnl::memory::data_type::s8;
 // weights type
-mkldnn::memory::data_type wType = xType;
-if(xType == mkldnn::memory::data_type::u8)
-wType = mkldnn::memory::data_type::s8;
+dnnl::memory::data_type wType = xType;
+if(xType == dnnl::memory::data_type::u8)
+wType = dnnl::memory::data_type::s8;
 // output and bias type (have the same types)
-mkldnn::memory::data_type zType;
+dnnl::memory::data_type zType;
 if(output->dataType() == DataType::FLOAT32)
-zType = mkldnn::memory::data_type::f32;
+zType = dnnl::memory::data_type::f32;
 else if(output->dataType() == DataType::HALF)
-zType = mkldnn::memory::data_type::f16;
+zType = dnnl::memory::data_type::f16;
 else if(output->dataType() == DataType::UINT8)
-zType = mkldnn::memory::data_type::u8;
+zType = dnnl::memory::data_type::u8;
 else if(output->dataType() == DataType::INT8)
-zType = mkldnn::memory::data_type::s8;
+zType = dnnl::memory::data_type::s8;
 else
-zType = mkldnn::memory::data_type::s32;
+zType = dnnl::memory::data_type::s32;
-mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
-mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
+dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
-mkldnn::memory::dims xDims = {bS, iC, iH, iW};
-mkldnn::memory::dims wDims = {oC, iC, kH, kW};
-mkldnn::memory::dims zDims = {bS, oC, oH, oW};
+dnnl::memory::dims xDims = {bS, iC, iH, iW};
+dnnl::memory::dims wDims = {oC, iC, kH, kW};
+dnnl::memory::dims zDims = {bS, oC, oH, oW};
 // memory descriptors for arrays
 // input
-mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
-mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
-x_user_md.data.format_kind = mkldnn_blocked; // overrides format
+dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
+dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
+x_user_md.data.format_kind = dnnl_blocked; // overrides format
 x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
 x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
 x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
 x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
 // weights
-mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
-mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
-w_user_md.data.format_kind = mkldnn_blocked; // overrides format
+dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
+dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
+w_user_md.data.format_kind = dnnl_blocked; // overrides format
 w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
 w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
 w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
 w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
 // bias
-mkldnn::memory::desc b_mkl_md;
+dnnl::memory::desc b_mkl_md;
 if(bias != nullptr)
-b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x);
+b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x);
 // output
-mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any);
-mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat);
-z_user_md.data.format_kind = mkldnn_blocked; // overrides format
+dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
+dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
+z_user_md.data.format_kind = dnnl_blocked; // overrides format
 z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
 z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
 z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
@@ -128,51 +128,51 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// operation primitive description
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct,
dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct,
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// bias
if(bias != nullptr) {
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer());
args[MKLDNN_ARG_BIAS] = b_mkl_mem;
auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
args[DNNL_ARG_BIAS] = b_mkl_mem;
}
// output
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer());
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[MKLDNN_ARG_DST] = z_mkl_mem;
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
// run calculations
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args);
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait();
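
The desc -> primitive_desc -> argument-map -> execute -> wait sequence above is the same for every oneDNN primitive. A self-contained sketch of that flow, using the much simpler eltwise (ReLU) primitive as a stand-in for deconvolution (oneDNN 1.1 assumed; sizes arbitrary; in-place for brevity):

#include <dnnl.hpp>
#include <unordered_map>
#include <vector>

int main() {
    dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    dnnl::stream stream(engine);

    dnnl::memory::dims dims = {1, 8, 4, 4};
    std::vector<float> buf(128, -1.0f);

    dnnl::memory::desc md(dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nchw);

    // operation primitive description, as with deconvolution_forward above
    dnnl::eltwise_forward::desc op_desc(dnnl::prop_kind::forward_inference,
                                        dnnl::algorithm::eltwise_relu, md, 0.f, 0.f);
    dnnl::eltwise_forward::primitive_desc op_prim_desc(op_desc, engine);

    dnnl::memory mem(md, engine, buf.data());

    // arguments (memory buffers) keyed by the renamed DNNL_ARG_* macros
    std::unordered_map<int, dnnl::memory> args;
    args[DNNL_ARG_SRC] = mem;
    args[DNNL_ARG_DST] = mem; // in-place ReLU

    dnnl::eltwise_forward(op_prim_desc).execute(stream, args);
    stream.wait();
    return 0;
}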
@@ -196,157 +196,157 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sH, sW };
mkldnn::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dHmkl, dWmkl };
dnnl::memory::dims strides = { sH, sW };
dnnl::memory::dims padding = { pH, pW };
dnnl::memory::dims padding_r = { pHmkl, pWmkl };
dnnl::memory::dims dilation = { dHmkl, dWmkl };
// input type
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// weights type
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradO type
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradI type
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradW type
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradB type
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32;
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kH, kW};
dnnl::memory::dims zDims = {bS, oC, oH, oW};
// memory descriptors for arrays
// input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
// weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
// gradO
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
// gradI
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
// gradW
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
// gradB
mkldnn::memory::desc gradB_mkl_md;
dnnl::memory::desc gradB_mkl_md;
if(gradB != nullptr)
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x);
gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward data primitive description
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
// backward weights primitive description
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder)
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
// gradW
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer());
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
// gradB
if(gradB != nullptr) {
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem;
auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
}
// run backward data calculations
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// run backward weights calculations
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
stream.wait();
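
One detail worth calling out in this hunk: both backward primitive descriptors are built with the forward primitive_desc as a hint (the third constructor argument). A minimal sketch of that rule on the eltwise primitive (oneDNN 1.1 assumed; forward_training is used here so the forward pd is a valid hint):

#include <dnnl.hpp>
#include <unordered_map>
#include <vector>

int main() {
    dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    dnnl::stream stream(engine);

    dnnl::memory::desc md({2, 16}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);

    dnnl::eltwise_forward::desc fwd_desc(dnnl::prop_kind::forward_training,
                                         dnnl::algorithm::eltwise_relu, md, 0.f, 0.f);
    dnnl::eltwise_forward::primitive_desc fwd_pd(fwd_desc, engine);

    // backward pd takes the forward pd as a hint, like op_data_bp_prim_desc above
    dnnl::eltwise_backward::desc bwd_desc(dnnl::algorithm::eltwise_relu, md, md, 0.f, 0.f);
    dnnl::eltwise_backward::primitive_desc bwd_pd(bwd_desc, engine, fwd_pd);

    std::vector<float> src(32, 1.f), diff_dst(32, 1.f), diff_src(32, 0.f);
    std::unordered_map<int, dnnl::memory> args;
    args[DNNL_ARG_SRC]      = dnnl::memory(md, engine, src.data());
    args[DNNL_ARG_DIFF_DST] = dnnl::memory(md, engine, diff_dst.data());
    args[DNNL_ARG_DIFF_SRC] = dnnl::memory(md, engine, diff_src.data());

    dnnl::eltwise_backward(bwd_pd).execute(stream, args);
    stream.wait();
    return 0;
}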

@@ -39,52 +39,52 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
// gradO [bS, oH, oW, oC]
mkldnn::memory::dims strides = { sH, sW };
mkldnn::memory::dims dilation = { dH - 1, dW - 1 };
mkldnn::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
dnnl::memory::dims strides = { sH, sW };
dnnl::memory::dims dilation = { dH - 1, dW - 1 };
dnnl::memory::dims padding = { pH, pW };
dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
// weights type
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradO type
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradI type
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kH, kW};
dnnl::memory::dims zDims = {bS, oC, oH, oW};
// memory descriptors for arrays
// input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, gradOType, mkldnn::memory::format_tag::any);
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, gradOType, dnnl::memory::format_tag::any);
// weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
// gradO
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
// gradI
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
@@ -94,48 +94,48 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description
mkldnn::convolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward data primitive description
mkldnn::convolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder)
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
// run backward data calculations
mkldnn::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
stream.wait();
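
A quick worked check of the padding_r formula above, with hypothetical sizes kH = 3, sH = 2, pH = 1, iH = 8, oH = 4 and the default dH = 1: padding_r = (oH - 1)*sH - iH + kH - pH = 3*2 - 8 + 3 - 1 = 0, and the hinted forward convolution then maps the gradI spatial extent to floor((iH + pH + padding_r - kH)/sH) + 1 = floor((8 + 1 + 0 - 3)/2) + 1 = 4 = oH. In other words, the asymmetric right/bottom padding is chosen exactly so the forward convolution's output shape matches gradO, which in turn lets convolution_backward_data recover gradI's shape.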

@@ -50,54 +50,54 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sD, sH, sW };
mkldnn::memory::dims padding = { pD, pH, pW };
mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
dnnl::memory::dims strides = { sD, sH, sW };
dnnl::memory::dims padding = { pD, pH, pW };
dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
// input type
mkldnn::memory::data_type xType;
dnnl::memory::data_type xType;
if(input->dataType() == DataType::FLOAT32)
xType = mkldnn::memory::data_type::f32;
xType = dnnl::memory::data_type::f32;
else if(input->dataType() == DataType::HALF)
xType = mkldnn::memory::data_type::f16;
xType = dnnl::memory::data_type::f16;
else if(input->dataType() == DataType::UINT8)
xType = mkldnn::memory::data_type::u8;
xType = dnnl::memory::data_type::u8;
else
xType = mkldnn::memory::data_type::s8;
xType = dnnl::memory::data_type::s8;
// weights type
mkldnn::memory::data_type wType = xType;
if(xType == mkldnn::memory::data_type::u8)
wType = mkldnn::memory::data_type::s8;
dnnl::memory::data_type wType = xType;
if(xType == dnnl::memory::data_type::u8)
wType = dnnl::memory::data_type::s8;
// output and bias type (have the same types)
mkldnn::memory::data_type zType;
dnnl::memory::data_type zType;
if(output->dataType() == DataType::FLOAT32)
zType = mkldnn::memory::data_type::f32;
zType = dnnl::memory::data_type::f32;
else if(output->dataType() == DataType::HALF)
zType = mkldnn::memory::data_type::f16;
zType = dnnl::memory::data_type::f16;
else if(output->dataType() == DataType::UINT8)
zType = mkldnn::memory::data_type::u8;
zType = dnnl::memory::data_type::u8;
else if(output->dataType() == DataType::INT8)
zType = mkldnn::memory::data_type::s8;
zType = dnnl::memory::data_type::s8;
else
zType = mkldnn::memory::data_type::s32;
zType = dnnl::memory::data_type::s32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw;
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW};
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
// memory descriptors for arrays
// input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
@@ -105,9 +105,9 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
// weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
@@ -115,14 +115,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
// bias
mkldnn::memory::desc b_mkl_md;
dnnl::memory::desc b_mkl_md;
if(bias != nullptr)
b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x);
b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x);
// output
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat);
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
@@ -132,51 +132,51 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// operation primitive description
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct,
dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct,
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// bias
if(bias != nullptr) {
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer());
args[MKLDNN_ARG_BIAS] = b_mkl_mem;
auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
args[DNNL_ARG_BIAS] = b_mkl_mem;
}
// output
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer());
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[MKLDNN_ARG_DST] = z_mkl_mem;
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
// run calculations
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args);
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait();
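
Throughout this commit, user-side strides are injected by reaching into the C struct (format_kind = dnnl_blocked plus blocking.strides). For reference, DNNL 1.x also exposes a strides-based memory::desc constructor that should express the same layout without touching .data (a sketch; the dims and strides here are made up):

#include <dnnl.hpp>

int main() {
    dnnl::memory::dims dims    = {2, 3, 4, 5};
    dnnl::memory::dims strides = {120, 40, 10, 2}; // e.g. a non-contiguous NDArray view

    // equivalent in effect to the dnnl_blocked override used in the hunks above
    dnnl::memory::desc user_md(dims, dnnl::memory::data_type::f32, strides);
    return 0;
}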
@@ -200,37 +200,37 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sD, sH, sW };
mkldnn::memory::dims padding = { pD, pH, pW };
mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
dnnl::memory::dims strides = { sD, sH, sW };
dnnl::memory::dims padding = { pD, pH, pW };
dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
// input type
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// weights type
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradO type
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradI type
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradW type
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// gradB type
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32;
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw; // isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw;
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw; // isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW};
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
// memory descriptors for arrays
// input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
@@ -238,9 +238,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
// weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
@@ -248,9 +248,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
// gradO
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
@@ -258,9 +258,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4];
// gradI
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
@@ -268,9 +268,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4];
// gradW
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
@@ -278,85 +278,85 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4];
// gradB
mkldnn::memory::desc gradB_mkl_md;
dnnl::memory::desc gradB_mkl_md;
if(gradB != nullptr)
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x);
gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward data primitive description
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
// backward weights primitive description
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
std::unordered_map<int, dnnl::memory> args;
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder)
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
// gradW
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer());
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
// gradB
if(gradB != nullptr) {
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem;
auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
}
// run backward data calculations
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// run backward weights calculations
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
stream.wait();
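
The weight/bias gradient keys (DNNL_ARG_DIFF_WEIGHTS, DNNL_ARG_DIFF_BIAS) follow the same convention across primitives. A self-contained sketch on the simpler inner_product primitive (oneDNN 1.1 assumed; plain layouts, arbitrary sizes):

#include <dnnl.hpp>
#include <unordered_map>
#include <vector>

int main() {
    dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    dnnl::stream stream(engine);
    auto f32 = dnnl::memory::data_type::f32;

    dnnl::memory::desc src_md({4, 8},  f32, dnnl::memory::format_tag::nc);
    dnnl::memory::desc w_md({16, 8},   f32, dnnl::memory::format_tag::oi);
    dnnl::memory::desc b_md({16},      f32, dnnl::memory::format_tag::x);
    dnnl::memory::desc dst_md({4, 16}, f32, dnnl::memory::format_tag::nc);

    dnnl::inner_product_forward::desc fwd_desc(dnnl::prop_kind::forward_training,
                                               src_md, w_md, b_md, dst_md);
    dnnl::inner_product_forward::primitive_desc fwd_pd(fwd_desc, engine);

    // backward-weights pd also takes the forward pd as a hint
    dnnl::inner_product_backward_weights::desc bwdw_desc(src_md, w_md, b_md, dst_md);
    dnnl::inner_product_backward_weights::primitive_desc bwdw_pd(bwdw_desc, engine, fwd_pd);

    std::vector<float> src(32, 1.f), diff_dst(64, 1.f), gradW(128, 0.f), gradB(16, 0.f);
    std::unordered_map<int, dnnl::memory> args;
    args[DNNL_ARG_SRC]          = dnnl::memory(src_md, engine, src.data());
    args[DNNL_ARG_DIFF_DST]     = dnnl::memory(dst_md, engine, diff_dst.data());
    args[DNNL_ARG_DIFF_WEIGHTS] = dnnl::memory(w_md, engine, gradW.data());
    args[DNNL_ARG_DIFF_BIAS]    = dnnl::memory(b_md, engine, gradB.data());

    dnnl::inner_product_backward_weights(bwdw_pd).execute(stream, args);
    stream.wait();
    return 0;
}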

@@ -27,7 +27,7 @@
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
using namespace dnnl;
namespace nd4j {
namespace ops {
@@ -44,8 +44,8 @@ namespace nd4j {
double bias = T_ARG(0);
int depth = INT_ARG(0);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);
dnnl_memory_desc_t empty;
dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);
mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md,
&user_src_md, nullptr, &user_dst_md, input->rankOf() - 1);
@@ -54,24 +54,24 @@ namespace nd4j {
lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto lrn_src_memory = user_src_memory;
if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) {
lrn_src_memory = mkldnn::memory(lrn_prim_desc.src_desc(), engine);
lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine);
reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory);
}
auto lrn_dst_memory = user_dst_memory;
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
lrn_dst_memory = mkldnn::memory(lrn_prim_desc.dst_desc(), engine);
lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine);
}
lrn_forward(lrn_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, lrn_src_memory},
{MKLDNN_ARG_DST, lrn_dst_memory}});
lrn_forward(lrn_prim_desc).execute(stream, {{DNNL_ARG_SRC, lrn_src_memory},
{DNNL_ARG_DST, lrn_dst_memory}});
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory);
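
Side note on the alpha * (2 * depth + 1) in this hunk: oneDNN's documented across-channel LRN computes dst = src / (k + (alpha / n) * sum(src^2))^beta with n = local_size, whereas nd4j's own lrn applies alpha to the sum directly; assuming that documented formula, pre-multiplying alpha by the window size 2*depth + 1 exactly cancels oneDNN's division by n.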

@@ -21,7 +21,7 @@
#include <ops/declarable/OpRegistrator.h>
#include "mkldnnUtils.h"
using namespace mkldnn;
using namespace dnnl;
namespace nd4j {
namespace ops {
@@ -132,52 +132,52 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
dnnl::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md;
// input type
mkldnn::memory::data_type xType;
dnnl::memory::data_type xType;
if(x->dataType() == DataType::FLOAT32)
xType = mkldnn::memory::data_type::f32;
xType = dnnl::memory::data_type::f32;
else if(x->dataType() == DataType::HALF)
xType = mkldnn::memory::data_type::f16;
xType = dnnl::memory::data_type::f16;
else
xType = mkldnn::memory::data_type::u8;
xType = dnnl::memory::data_type::u8;
// weights type
mkldnn::memory::data_type wType = xType;
if(xType == mkldnn::memory::data_type::u8)
wType = mkldnn::memory::data_type::s8;
dnnl::memory::data_type wType = xType;
if(xType == dnnl::memory::data_type::u8)
wType = dnnl::memory::data_type::s8;
// bias type
mkldnn::memory::data_type bType = xType;
if(xType == mkldnn::memory::data_type::u8)
bType = mkldnn::memory::data_type::f32;
dnnl::memory::data_type bType = xType;
if(xType == dnnl::memory::data_type::u8)
bType = dnnl::memory::data_type::f32;
// output type
mkldnn::memory::data_type hType;
dnnl::memory::data_type hType;
if(h->dataType() == DataType::FLOAT32)
hType = mkldnn::memory::data_type::f32;
hType = dnnl::memory::data_type::f32;
else if(h->dataType() == DataType::HALF)
hType = mkldnn::memory::data_type::f16;
hType = dnnl::memory::data_type::f16;
else
hType = mkldnn::memory::data_type::u8;
hType = dnnl::memory::data_type::u8;
// memory descriptors for arrays
// x
x_lstm_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::any);
// x_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, nIn}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, nIn}, type, mkldnn::memory::format_tag::ntc);
x_user_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::tnc);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any);
// x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc);
x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
// wx
wx_lstm_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::any);
wx_user_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
wx_user_md.data.format_kind = mkldnn_blocked; // overrides format
wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any);
wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
wx_user_md.data.format_kind = dnnl_blocked; // overrides format
wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0];
wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1];
wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2];
@ -185,9 +185,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4];
// wr
wr_lstm_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::any);
wr_user_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
wr_user_md.data.format_kind = mkldnn_blocked; // overrides format
wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any);
wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
wr_user_md.data.format_kind = dnnl_blocked; // overrides format
wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0];
wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1];
wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2];
@ -195,19 +195,19 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4];
// h
h_lstm_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::any);
// h_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, hDirDim*nOut}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, hDirDim*nOut}, type, mkldnn::memory::format_tag::ntc);
h_user_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::tnc);
h_user_md.data.format_kind = mkldnn_blocked; // overrides format
h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any);
// h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc);
h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc);
h_user_md.data.format_kind = dnnl_blocked; // overrides format
h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0];
h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1];
h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2];
// b
if(b) {
b_lstm_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::any);
b_user_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::ldgo);
b_user_md.data.format_kind = mkldnn_blocked; // overrides format
b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any);
b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo);
b_user_md.data.format_kind = dnnl_blocked; // overrides format
b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0];
b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1];
b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2];
@ -216,9 +216,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// hI
if(hI) {
hI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
hI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
hI_user_md.data.format_kind = mkldnn_blocked; // overrides format
hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
hI_user_md.data.format_kind = dnnl_blocked; // overrides format
hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0];
hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1];
hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2];
@ -227,9 +227,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// cI
if(cI) {
cI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
cI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
cI_user_md.data.format_kind = mkldnn_blocked; // overrides format
cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
cI_user_md.data.format_kind = dnnl_blocked; // overrides format
cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0];
cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1];
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2];
@ -238,9 +238,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// hL
if(hL) {
hL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::any);
hL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
hL_user_md.data.format_kind = mkldnn_blocked; // overrides format
hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any);
hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
hL_user_md.data.format_kind = dnnl_blocked; // overrides format
hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0];
hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1];
hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2];
@ -248,9 +248,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
}
if(cL) {
cL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
cL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
cL_user_md.data.format_kind = mkldnn_blocked; // overrides format
cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
cL_user_md.data.format_kind = dnnl_blocked; // overrides format
cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0];
cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1];
cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2];
@ -262,92 +262,92 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md,
h_lstm_md, hL_lstm_md, cL_lstm_md);
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
// lstm primitive description
lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
std::unordered_map<int, dnnl::memory> args;
// provide memory and check whether reorder is required
// x
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc();
auto x_lstm_mem = xReorder ? mkldnn::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
auto x_lstm_mem = xReorder ? dnnl::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
if (xReorder)
reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem);
args[MKLDNN_ARG_SRC_LAYER] = x_lstm_mem;
args[DNNL_ARG_SRC_LAYER] = x_lstm_mem;
// wx
auto wx_user_mem = mkldnn::memory(wx_user_md, engine, Wx->getBuffer());
auto wx_user_mem = dnnl::memory(wx_user_md, engine, Wx->getBuffer());
const bool wxReorder = lstm_prim_desc.weights_layer_desc() != wx_user_mem.get_desc();
auto wx_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
auto wx_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
if (wxReorder)
reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem);
args[MKLDNN_ARG_WEIGHTS_LAYER] = wx_lstm_mem;
args[DNNL_ARG_WEIGHTS_LAYER] = wx_lstm_mem;
// wr
auto wr_user_mem = mkldnn::memory(wr_user_md, engine, Wr->getBuffer());
auto wr_user_mem = dnnl::memory(wr_user_md, engine, Wr->getBuffer());
const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc();
auto wr_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
auto wr_lstm_mem = wrReorder ? dnnl::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
if (wrReorder)
reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem);
args[MKLDNN_ARG_WEIGHTS_ITER] = wr_lstm_mem;
args[DNNL_ARG_WEIGHTS_ITER] = wr_lstm_mem;
// h
auto h_user_mem = mkldnn::memory(h_user_md, engine, h->getBuffer());
auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer());
const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
auto h_lstm_mem = hReorder ? mkldnn::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
args[MKLDNN_ARG_DST_LAYER] = h_lstm_mem;
auto h_lstm_mem = hReorder ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
args[DNNL_ARG_DST_LAYER] = h_lstm_mem;
// b
if(b) {
auto b_user_mem = mkldnn::memory(b_user_md, engine, b->getBuffer());
auto b_user_mem = dnnl::memory(b_user_md, engine, b->getBuffer());
const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc();
auto b_lstm_mem = bReorder ? mkldnn::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
auto b_lstm_mem = bReorder ? dnnl::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
if (bReorder)
reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem);
args[MKLDNN_ARG_BIAS] = b_lstm_mem;
args[DNNL_ARG_BIAS] = b_lstm_mem;
}
// hI
if(hI) {
auto hI_user_mem = mkldnn::memory(hI_user_md, engine, hI->getBuffer());
auto hI_user_mem = dnnl::memory(hI_user_md, engine, hI->getBuffer());
const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc();
auto hI_lstm_mem = hIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
auto hI_lstm_mem = hIReorder ? dnnl::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
if (hIReorder)
reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem);
args[MKLDNN_ARG_SRC_ITER] = hI_lstm_mem;
args[DNNL_ARG_SRC_ITER] = hI_lstm_mem;
}
// cI
if(cI) {
auto cI_user_mem = mkldnn::memory(cI_user_md, engine, cI->getBuffer());
auto cI_user_mem = dnnl::memory(cI_user_md, engine, cI->getBuffer());
const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc();
auto cI_lstm_mem = cIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
auto cI_lstm_mem = cIReorder ? dnnl::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
if (cIReorder)
reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem);
args[MKLDNN_ARG_SRC_ITER_C] = cI_lstm_mem;
args[DNNL_ARG_SRC_ITER_C] = cI_lstm_mem;
}
bool hLReorder(false), cLReorder(false);
mkldnn::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
// hL
if(hL) {
hL_user_mem = mkldnn::memory(hL_user_md, engine, hL->getBuffer());
hL_user_mem = dnnl::memory(hL_user_md, engine, hL->getBuffer());
hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc();
hL_lstm_mem = hLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
args[MKLDNN_ARG_DST_ITER] = hL_lstm_mem;
hL_lstm_mem = hLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
args[DNNL_ARG_DST_ITER] = hL_lstm_mem;
}
// cL
if(cL) {
cL_user_mem = mkldnn::memory(cL_user_md, engine, cL->getBuffer());
cL_user_mem = dnnl::memory(cL_user_md, engine, cL->getBuffer());
cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc();
cL_lstm_mem = cLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
args[MKLDNN_ARG_DST_ITER_C] = cL_lstm_mem;
cL_lstm_mem = cLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem;
}
// run calculations
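Each tensor above repeats the same three steps: wrap the user buffer in a dnnl::memory, reorder it if the primitive prefers a different layout, and register it in the args map under its DNNL_ARG_* id. For the input-side tensors that pattern could be folded into one helper; a sketch only, with a hypothetical helper name that is not part of this commit:

    #include <dnnl.hpp>
    #include <unordered_map>

    // Hypothetical input-side helper: wrap the user buffer, reorder it into the
    // primitive's preferred layout if needed, then register it under argId.
    static void provideInput(const dnnl::memory::desc& userMd,
                             const dnnl::memory::desc& primMd,
                             void* buffer, int argId,
                             dnnl::engine& engine, dnnl::stream& stream,
                             std::unordered_map<int, dnnl::memory>& args) {
        auto userMem = dnnl::memory(userMd, engine, buffer);
        if (primMd == userMem.get_desc()) {
            args[argId] = userMem;                    // layouts already agree
            return;
        }
        auto primMem = dnnl::memory(primMd, engine);  // primitive-preferred layout
        dnnl::reorder(userMem, primMem).execute(stream, userMem, primMem);
        args[argId] = primMem;
    }

Output-side tensors (h, hL, cL) need the opposite direction: the primitive writes into its preferred layout first, and the reorder back into the user buffer runs only after execution, which is why the code above merely records their reorder flags here.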


@ -27,7 +27,7 @@
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
using namespace dnnl;
namespace nd4j {
namespace ops {
@ -82,11 +82,11 @@ namespace nd4j {
auto poolingMode = PoolingType::MAX_POOL;
int extraParam0 = 1;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
@ -102,23 +102,23 @@ namespace nd4j {
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}});
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
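The pooling descriptor that feeds this code path is built the same way under both names; a condensed sketch of a 2x2, stride-2 max-pool primitive descriptor under dnnl:: (dims and hyper-parameters are illustrative, not taken from the commit):

    #include <dnnl.hpp>

    // Sketch: max-pooling primitive descriptor with the renamed API.
    static dnnl::pooling_forward::primitive_desc
    makeMaxPoolPd(dnnl::engine& engine) {
        dnnl::memory::desc src({1, 8, 32, 32}, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nchw);
        dnnl::memory::desc dst({1, 8, 16, 16}, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nchw);
        dnnl::pooling_forward::desc d(dnnl::prop_kind::forward_inference,
                                      dnnl::algorithm::pooling_max, src, dst,
                                      /*strides*/ {2, 2}, /*kernel*/ {2, 2},
                                      /*pad_l*/ {0, 0}, /*pad_r*/ {0, 0});
        return dnnl::pooling_forward::primitive_desc(d, engine);
    }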


@ -27,7 +27,7 @@
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
using namespace dnnl;
namespace nd4j {
namespace ops {
@ -89,11 +89,11 @@ namespace nd4j {
auto poolingMode = PoolingType::MAX_POOL;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
@ -109,44 +109,44 @@ namespace nd4j {
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine);
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}});
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
// probably wrong, fix that
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
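The workspace memory here is what lets pooling_backward route each incoming gradient to the element that won the forward max; it is only non-empty when the forward primitive descriptor was built with a training prop_kind (dnnl::prop_kind::forward is an alias for forward_training). Condensed into a self-contained sketch, with the primitive descriptors and memories passed in as assumed parameters:

    #include <dnnl.hpp>

    // Sketch: forward saves argmax positions in the workspace,
    // backward consumes them to scatter the incoming gradient.
    static void maxpoolBackwardSketch(
            dnnl::engine& engine, dnnl::stream& stream,
            const dnnl::pooling_forward::primitive_desc& fwdPd,
            const dnnl::pooling_backward::primitive_desc& bwdPd,
            dnnl::memory& src, dnnl::memory& dst,
            dnnl::memory& diffDst, dnnl::memory& diffSrc) {
        auto ws = dnnl::memory(fwdPd.workspace_desc(), engine);
        dnnl::pooling_forward(fwdPd).execute(stream,
                {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
                 {DNNL_ARG_WORKSPACE, ws}});
        dnnl::pooling_backward(bwdPd).execute(stream,
                {{DNNL_ARG_DIFF_DST, diffDst},
                 {DNNL_ARG_WORKSPACE, ws},
                 {DNNL_ARG_DIFF_SRC, diffSrc}});
    }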


@ -26,7 +26,7 @@
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
using namespace dnnl;
namespace nd4j {
namespace ops {
@ -87,11 +87,11 @@ namespace nd4j {
auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
@ -106,24 +106,24 @@ namespace nd4j {
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory}});
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
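The 3-D pooling sections differ from their 2-D counterparts only in the extra depth dimension and the format tags (ncdhw/ndhwc and the blocked nCdhw8c replace nchw/nhwc and nChw8c); the descriptor/reorder/execute flow, and hence the rename, is identical.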


@ -26,7 +26,7 @@
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
using namespace dnnl;
namespace nd4j {
namespace ops {
@ -93,11 +93,11 @@ namespace nd4j {
auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1;
mkldnn_memory_desc_t empty;
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
mkldnn::algorithm algorithm;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
@ -115,44 +115,44 @@ namespace nd4j {
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine);
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
{MKLDNN_ARG_DST, pool_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}});
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory},
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {


@ -18,23 +18,23 @@
// @author saudet
//
#include <mkldnn_types.h>
#include <dnnl_types.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
using namespace dnnl;
namespace nd4j {
namespace mkldnnUtils {
void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW };
mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW };
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW };
dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW };
pool_strides = { sH, sW };
pool_kernel = { kH, kW };
@ -45,14 +45,14 @@ namespace nd4j {
algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding;
auto type = mkldnn::memory::data_type::f32;
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
auto type = dnnl::memory::data_type::f32;
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
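A format_tag of any defers the layout choice to the primitive, which then reports its preferred (possibly blocked) layout through the primitive descriptor; the nChw8c tag here pins one concrete 8-channel blocked layout instead, since, per the inline comment, the pooling descriptors in this path were not accepted with any.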
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
@ -60,9 +60,9 @@ namespace nd4j {
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
@ -70,9 +70,9 @@ namespace nd4j {
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
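The format_kind/blocking-strides override repeated through these helpers predates wider use of the strided memory::desc constructor; under DNNL 1.x the same user descriptor can be expressed directly from the NDArray strides. A sketch with a hypothetical helper name, mirroring the stride indexing the commit already uses:

    #include <dnnl.hpp>

    // Hypothetical: build a strided f32 user descriptor without touching md.data.
    static dnnl::memory::desc stridedUserMd(const nd4j::NDArray* arr,
                                            const dnnl::memory::dims& dims,
                                            bool isNCHW) {
        // order the strides to match the logical NCHW dims
        dnnl::memory::dims strides = {arr->stridesOf()[0],
                                      arr->stridesOf()[isNCHW ? 1 : 3],
                                      arr->stridesOf()[isNCHW ? 2 : 1],
                                      arr->stridesOf()[isNCHW ? 3 : 2]};
        return dnnl::memory::desc(dims, dnnl::memory::data_type::f32, strides);
    }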
@ -84,12 +84,12 @@ namespace nd4j {
void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
pool_strides = { sD, sH, sW };
pool_kernel = { kD, kH, kW };
@ -101,14 +101,14 @@ namespace nd4j {
algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding;
auto type = mkldnn::memory::data_type::f32;
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nCdhw8c; // doesn't work with "any"
auto type = dnnl::memory::data_type::f32;
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
@ -117,9 +117,9 @@ namespace nd4j {
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
@ -128,9 +128,9 @@ namespace nd4j {
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
@ -145,15 +145,15 @@ namespace nd4j {
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW };
dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW };
dnnl::memory::dims conv_bias_tz = { oC };
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
@ -169,14 +169,14 @@ namespace nd4j {
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
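The two conv_padding_r entries are the output-size relation solved for the right padding: from oH = (iH + pH + padR - kH) / sH + 1 it follows that padR = (oH - 1) * sH - iH + kH - pH, and likewise for width (dilation is folded in separately through the dHmkl/dWmkl values computed above).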
auto type = mkldnn::memory::data_type::f32;
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
auto formatw = mkldnn::memory::format_tag::hwio;
auto type = dnnl::memory::data_type::f32;
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto formatw = dnnl::memory::format_tag::hwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
@ -184,9 +184,9 @@ namespace nd4j {
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
@ -194,9 +194,9 @@ namespace nd4j {
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio"
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
@ -204,9 +204,9 @@ namespace nd4j {
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio"
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
@ -214,14 +214,14 @@ namespace nd4j {
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any);
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x);
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any);
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
@ -233,15 +233,15 @@ namespace nd4j {
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
dnnl::memory::dims conv_bias_tz = { oC };
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
@ -251,14 +251,14 @@ namespace nd4j {
conv_padding_r = { pDmkl, pHmkl, pWmkl };
conv_dilation = { dDmkl, dHmkl, dWmkl };
auto type = mkldnn::memory::data_type::f32;
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
auto formatw = mkldnn::memory::format_tag::dhwio;
auto type = dnnl::memory::data_type::f32;
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
auto formatw = dnnl::memory::format_tag::dhwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
@ -267,9 +267,9 @@ namespace nd4j {
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
@ -278,9 +278,9 @@ namespace nd4j {
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio"
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
@ -289,9 +289,9 @@ namespace nd4j {
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio"
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
@ -300,14 +300,14 @@ namespace nd4j {
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any);
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x);
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any);
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
@ -318,23 +318,23 @@ namespace nd4j {
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
// mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
// mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
// const Nd4jLong* shape = src->getShapeInfo();
// Nd4jLong rank = shape[0];
// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
// Nd4jLong dim2 = axis >= 2 ? 1 : 2;
// Nd4jLong dim3 = axis >= 3 ? 2 : 3;
// mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
// auto type = mkldnn::memory::data_type::f32;
// auto format = mkldnn::memory::format_tag::nchw;
// auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
// auto type = dnnl::memory::data_type::f32;
// auto format = dnnl::memory::format_tag::nchw;
// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
// *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
// user_src_md->data.format_kind = mkldnn_blocked; // overrides format
// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_src_md->data.format_kind = dnnl_blocked; // overrides format
// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
@ -342,9 +342,9 @@ namespace nd4j {
// }
// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
// *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
// user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
@ -352,9 +352,9 @@ namespace nd4j {
// }
// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
// *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
// user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_dst_md->data.format_kind = dnnl_blocked; // overrides format
// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
@ -364,23 +364,23 @@ namespace nd4j {
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo();
long rank = shape[0];
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
long dim2 = axis >= 2 ? 1 : 2;
long dim3 = axis >= 3 ? 2 : 3;
mkldnn::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = mkldnn::memory::data_type::f32;
auto format = axis == 1 ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
auto type = dnnl::memory::data_type::f32;
auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto supposed_to_be_any_format = format; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
*lrn_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked;
*lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked;
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
@ -388,9 +388,9 @@ namespace nd4j {
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
*lrn_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked;
*lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked;
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
@ -398,9 +398,9 @@ namespace nd4j {
}
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
*lrn_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked;
*lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked;
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
@ -408,8 +408,8 @@ namespace nd4j {
}
}
mkldnn::engine& getEngine(void *ptr) {
auto eng = reinterpret_cast<mkldnn::engine*>(ptr);
dnnl::engine& getEngine(void *ptr) {
auto eng = reinterpret_cast<dnnl::engine*>(ptr);
return *eng;
}
}
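
As a sanity check on the rename, here is a hedged, self-contained sketch of the v1.1 call chain these descriptors feed into, using an LRN forward primitive; the dims and LRN hyper-parameters are illustrative placeholders, not values from this commit:

#include <dnnl.hpp>
#include <vector>

int main() {
    dnnl::engine eng(dnnl::engine::kind::cpu, 0);  // same engine kind getEngine() returns
    dnnl::stream strm(eng);

    // nchw layout, matching the axis == 1 branch of getMKLDNNMemoryDescLrn
    dnnl::memory::desc src_md({1, 8, 4, 4}, dnnl::memory::data_type::f32,
                              dnnl::memory::format_tag::nchw);

    dnnl::lrn_forward::desc d(dnnl::prop_kind::forward_inference,
                              dnnl::algorithm::lrn_across_channels,
                              src_md, /*local_size=*/5,
                              /*alpha=*/1e-4f, /*beta=*/0.75f, /*k=*/1.0f);
    dnnl::lrn_forward::primitive_desc pd(d, eng);

    std::vector<float> data(1 * 8 * 4 * 4, 1.0f);
    dnnl::memory src(pd.src_desc(), eng, data.data());
    dnnl::memory dst(pd.dst_desc(), eng);

    dnnl::lrn_forward(pd).execute(strm, {{DNNL_ARG_SRC, src},
                                         {DNNL_ARG_DST, dst}});
    strm.wait();
    return 0;
}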


@ -23,7 +23,7 @@
#include <NativeOps.h>
#include <NDArray.h>
#include <mkldnn.hpp>
#include <dnnl.hpp>
#include <MKLDNNStream.h>
#include <graph/Context.h>
#include <ops/declarable/PlatformHelper.h>
@ -89,47 +89,47 @@ namespace nd4j{
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation);
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
void getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation);
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis);
dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis);
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
mkldnn::engine& getEngine(void *ptr);
dnnl::engine& getEngine(void *ptr);
}
}
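
A hedged caller sketch for the LRN helper declared above; "input" and "output" are assumed NDArray pointers owned by the caller, and the nullptr arguments disable the backward (diff_src) descriptors, which the implementation guards against:

// Hypothetical usage, not from this commit: descriptors start out as an
// empty C descriptor and get filled in by the helper.
dnnl_memory_desc_t empty;
dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);

mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output,
                                    &lrn_src_md, nullptr, &lrn_dst_md,
                                    &user_src_md, nullptr, &user_dst_md,
                                    /*axis=*/1);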


@ -42,7 +42,7 @@ if ("${BUILD_MKLDNN}")
set(HAVE_MKLDNN 1)
add_definitions("-DHAVE_MKLDNN")
include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_SOURCE_DIR}/external/mklml_lnx_2019.0.3.20190220/include ${mkldnn_SOURCE_DIR})
set(MKLDNN mkldnn)
set(MKLDNN dnnl)
endif()
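
Downstream targets see the rename only through the ${MKLDNN} variable set above; a hedged consumer sketch, where "mytarget" is a placeholder and not a target defined in this build:

if (HAVE_MKLDNN)
    # resolves to the static "dnnl" library now; before v1.1 this was "mkldnn"
    target_link_libraries(mytarget ${MKLDNN})
endif()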
# Download and unpack flatbuffers at configure time