- MKL-DNN version upgrade to 1.1.x (#62)
- MKL-DNN namespace changes to match DNNL rename
Signed-off-by: raver119 <raver119@gmail.com>
master
parent
7898f3c0cc
commit
59e955cedc
|
@ -51,7 +51,7 @@ endif()
|
|||
|
||||
if(NOT CUDA_BLAS)
|
||||
# we need this definition to avoid global memory use within mkldnn
|
||||
add_definitions(-DMKLDNN_ENABLE_CONCURRENT_EXEC=true)
|
||||
add_definitions(-DDNNL_ENABLE_CONCURRENT_EXEC=true)
|
||||
|
||||
# there's a chance, we have no BLAS provided externally
|
||||
if ("${OPENBLAS_PATH}" STREQUAL "")
|
||||
|
@ -122,7 +122,7 @@ if(NOT CUDA_BLAS)
|
|||
if (${HELPERS_mkldnn})
|
||||
message("Going to pull & build mkldnn")
|
||||
set(HAVE_MKLDNN 1)
|
||||
set(MKLDNN_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE)
|
||||
set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE)
|
||||
|
||||
configure_file(./CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt)
|
||||
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
|
||||
|
@ -146,7 +146,7 @@ if(NOT CUDA_BLAS)
|
|||
set(mkldnn_EXT_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src)
|
||||
set(MKLDNN_PATH "${mkldnn_SOURCE_DIR}")
|
||||
include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_EXT_DIR}/include ${mkldnn_SOURCE_DIR})
|
||||
set(MKLDNN mkldnn)
|
||||
set(MKLDNN dnnl)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
|
@ -5,11 +5,11 @@ project(mkldnn-download NONE)
|
|||
include(ExternalProject)
|
||||
ExternalProject_Add(mkldnn
|
||||
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
||||
GIT_TAG v1.0.4
|
||||
GIT_TAG v1.1.1
|
||||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
||||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
||||
CONFIGURE_COMMAND ""
|
||||
CMAKE_ARGS -DMKLDNN_USE_MKL=ML -DMKLDNN_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\"
|
||||
CMAKE_ARGS -DDNNL_USE_MKL=ML -DDNNL_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\"
|
||||
BUILD_COMMAND ""
|
||||
INSTALL_COMMAND ""
|
||||
TEST_COMMAND ""
|
||||
|
|
|
@ -30,14 +30,14 @@ thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
|
|||
#endif
|
||||
|
||||
#ifdef HAVE_MKLDNN
|
||||
#include <mkldnn.hpp>
|
||||
#include <dnnl.hpp>
|
||||
#endif
|
||||
|
||||
namespace nd4j {
|
||||
|
||||
LaunchContext::~LaunchContext() {
|
||||
#ifdef HAVE_MKLDNN
|
||||
delete reinterpret_cast<mkldnn::engine*>(_engine);
|
||||
delete reinterpret_cast<dnnl::engine*>(_engine);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -50,7 +50,7 @@ namespace nd4j {
|
|||
_deviceID = 0;
|
||||
|
||||
#ifdef HAVE_MKLDNN
|
||||
_engine = new mkldnn::engine(mkldnn::engine::kind::cpu, 0);
|
||||
_engine = new dnnl::engine(dnnl::engine::kind::cpu, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -82,11 +82,11 @@ namespace nd4j {
|
|||
|
||||
auto poolingMode = PoolingType::AVG_POOL;
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
mkldnn::algorithm algorithm;
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
|
||||
|
@ -99,20 +99,20 @@ namespace nd4j {
|
|||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
auto pool_src_memory = user_src_memory;
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||
{MKLDNN_ARG_DST, pool_dst_memory}});
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -89,11 +89,11 @@ namespace nd4j {
|
|||
|
||||
auto poolingMode = PoolingType::AVG_POOL;
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
mkldnn::algorithm algorithm;
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
|
||||
|
@ -109,20 +109,20 @@ namespace nd4j {
|
|||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
|
||||
pool_kernel, pool_padding, pool_padding_r);
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
|
||||
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -86,11 +86,11 @@ namespace nd4j {
|
|||
|
||||
auto poolingMode = PoolingType::AVG_POOL;
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
mkldnn::algorithm algorithm;
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||
extraParam0, true,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
|
||||
|
@ -102,21 +102,21 @@ namespace nd4j {
|
|||
pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||
{MKLDNN_ARG_DST, pool_dst_memory}});
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -92,11 +92,11 @@ namespace nd4j {
|
|||
|
||||
auto poolingMode = PoolingType::AVG_POOL;
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
mkldnn::algorithm algorithm;
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||
extraParam0, true,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
|
||||
|
@ -111,24 +111,24 @@ namespace nd4j {
|
|||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
|
||||
pool_kernel, pool_padding, pool_padding_r);
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
|
||||
auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ namespace platforms {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) {
|
||||
|
||||
// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
|
||||
// unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
|
||||
// also it gives wrong results for formats nhwc and ndhwc
|
||||
|
||||
// x -> 2D:nc, 4D:nchw, 5D:ncdhw
|
||||
|
@ -53,35 +53,35 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
|
|||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
|
||||
dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||
|
||||
// indicate whether gamma or/and beta are given
|
||||
auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
|
||||
auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
|
||||
if (weights != nullptr)
|
||||
flags |= mkldnn::normalization_flags::use_scale_shift;
|
||||
flags |= dnnl::normalization_flags::use_scale_shift;
|
||||
|
||||
mkldnn::memory::dims dims;
|
||||
mkldnn::memory::format_tag format;
|
||||
dnnl::memory::dims dims;
|
||||
dnnl::memory::format_tag format;
|
||||
|
||||
if(xRank == 2) {
|
||||
dims = {x->sizeAt(0), x->sizeAt(1)};
|
||||
format = mkldnn::memory::format_tag::nc;
|
||||
format = dnnl::memory::format_tag::nc;
|
||||
}
|
||||
else if(xRank == 4) {
|
||||
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
|
||||
format = mkldnn::memory::format_tag::nchw;
|
||||
format = dnnl::memory::format_tag::nchw;
|
||||
}
|
||||
else { // xRank = 5
|
||||
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
|
||||
format = mkldnn::memory::format_tag::ncdhw;
|
||||
format = dnnl::memory::format_tag::ncdhw;
|
||||
}
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// x
|
||||
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format);
|
||||
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
|
||||
if(xRank > 2) {
|
||||
|
@ -92,9 +92,9 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
|
|||
x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
|
||||
|
||||
// z, output
|
||||
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(dims, type, format);
|
||||
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, format);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0];
|
||||
z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1];
|
||||
if(xRank > 2) {
|
||||
|
@ -106,53 +106,53 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
|
|||
|
||||
|
||||
// batchnorm forward description
|
||||
mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
|
||||
mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
|
||||
dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, mkldnn::memory> args;
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// provide memory and check whether reorder is required
|
||||
|
||||
// x
|
||||
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
|
||||
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
|
||||
const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||
auto x_mkl_mem = xReorder ? mkldnn::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
if (xReorder)
|
||||
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||
|
||||
// z
|
||||
auto z_user_mem = mkldnn::memory(z_user_md, engine, z->getBuffer());
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
|
||||
const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? mkldnn::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
if (zReorder)
|
||||
mkldnn::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
|
||||
args[MKLDNN_ARG_DST] = z_mkl_mem;
|
||||
dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
|
||||
// mean
|
||||
auto mean_mkl_mem = mkldnn::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer());
|
||||
args[MKLDNN_ARG_MEAN] = mean_mkl_mem;
|
||||
auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer());
|
||||
args[DNNL_ARG_MEAN] = mean_mkl_mem;
|
||||
|
||||
// variance
|
||||
auto var_mkl_mem = mkldnn::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer());
|
||||
args[MKLDNN_ARG_VARIANCE] = var_mkl_mem;
|
||||
auto var_mkl_mem = dnnl::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer());
|
||||
args[DNNL_ARG_VARIANCE] = var_mkl_mem;
|
||||
|
||||
// gamma and beta (and their gradients) if they are present
|
||||
if(weights != nullptr) {
|
||||
|
||||
auto w_mkl_mem = mkldnn::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer());
|
||||
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||
auto w_mkl_mem = dnnl::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer());
|
||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||
}
|
||||
|
||||
// run calculations
|
||||
mkldnn::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);
|
||||
dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
@ -164,7 +164,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
|
|||
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
|
||||
const float epsilon, NDArray* dLdI, NDArray* dLdW) {
|
||||
|
||||
// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
|
||||
// unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
|
||||
// also it gives wrong results for formats nhwc and ndhwc
|
||||
|
||||
// x -> 2D:nc, 4D:nchw, 5D:ncdhw
|
||||
|
@ -180,35 +180,35 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
|||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
|
||||
dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||
|
||||
// indicate whether gamma or/and beta are given
|
||||
auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
|
||||
auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
|
||||
if (weights != nullptr)
|
||||
flags |= mkldnn::normalization_flags::use_scale_shift;
|
||||
flags |= dnnl::normalization_flags::use_scale_shift;
|
||||
|
||||
mkldnn::memory::dims dims;
|
||||
mkldnn::memory::format_tag format;
|
||||
dnnl::memory::dims dims;
|
||||
dnnl::memory::format_tag format;
|
||||
|
||||
if(xRank == 2) {
|
||||
dims = {x->sizeAt(0), x->sizeAt(1)};
|
||||
format = mkldnn::memory::format_tag::nc;
|
||||
format = dnnl::memory::format_tag::nc;
|
||||
}
|
||||
else if(xRank == 4) {
|
||||
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
|
||||
format = mkldnn::memory::format_tag::nchw;
|
||||
format = dnnl::memory::format_tag::nchw;
|
||||
}
|
||||
else { // xRank = 5
|
||||
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
|
||||
format = mkldnn::memory::format_tag::ncdhw;
|
||||
format = dnnl::memory::format_tag::ncdhw;
|
||||
}
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// x
|
||||
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format);
|
||||
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
|
||||
if(xRank > 2) {
|
||||
|
@ -219,9 +219,9 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
|||
x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
|
||||
|
||||
// dLdO
|
||||
mkldnn::memory::desc dLdO_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||
mkldnn::memory::desc dLdO_user_md = mkldnn::memory::desc(dims, type, format);
|
||||
dLdO_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, format);
|
||||
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
|
||||
dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0];
|
||||
dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1];
|
||||
if(xRank > 2) {
|
||||
|
@ -232,9 +232,9 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
|||
dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4];
|
||||
|
||||
// dLdI
|
||||
mkldnn::memory::desc dLdI_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||
mkldnn::memory::desc dLdI_user_md = mkldnn::memory::desc(dims, type, format);
|
||||
dLdI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, format);
|
||||
dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
|
||||
dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0];
|
||||
dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1];
|
||||
if(xRank > 2) {
|
||||
|
@ -245,66 +245,66 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
|||
dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4];
|
||||
|
||||
// batchnorm forward description
|
||||
mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
|
||||
mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
|
||||
dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
|
||||
// batchnorm backprop description
|
||||
mkldnn::batch_normalization_backward::desc op_bp_desc(mkldnn::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags);
|
||||
mkldnn::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
|
||||
dnnl::batch_normalization_backward::desc op_bp_desc(dnnl::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags);
|
||||
dnnl::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, mkldnn::memory> args;
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// provide memory and check whether reorder is required
|
||||
|
||||
// x
|
||||
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
|
||||
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
|
||||
const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||
auto x_mkl_mem = xReorder ? mkldnn::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
if (xReorder)
|
||||
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||
|
||||
// dLdO
|
||||
auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer());
|
||||
auto dLdO_user_mem = dnnl::memory(dLdO_user_md, engine, dLdO->getBuffer());
|
||||
const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc();
|
||||
auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
|
||||
auto dLdO_mkl_mem = dLdOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
|
||||
if (dLdOReorder)
|
||||
mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
|
||||
args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem;
|
||||
dnnl::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
|
||||
args[DNNL_ARG_DIFF_DST] = dLdO_mkl_mem;
|
||||
|
||||
// mean
|
||||
auto mean_mkl_mem = mkldnn::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
|
||||
args[MKLDNN_ARG_MEAN] = mean_mkl_mem;
|
||||
auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
|
||||
args[DNNL_ARG_MEAN] = mean_mkl_mem;
|
||||
|
||||
// variance
|
||||
auto var_mkl_mem = mkldnn::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer());
|
||||
args[MKLDNN_ARG_VARIANCE] = var_mkl_mem;
|
||||
auto var_mkl_mem = dnnl::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer());
|
||||
args[DNNL_ARG_VARIANCE] = var_mkl_mem;
|
||||
|
||||
// dLdI
|
||||
auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer());
|
||||
auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->getBuffer());
|
||||
const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc();
|
||||
auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
|
||||
args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem;
|
||||
auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem;
|
||||
|
||||
// gamma and beta (and their gradients) if they are present
|
||||
if(weights != nullptr) {
|
||||
|
||||
auto w_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer());
|
||||
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||
auto w_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer());
|
||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||
|
||||
auto dLdW_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer());
|
||||
args[MKLDNN_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem;
|
||||
auto dLdW_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer());
|
||||
args[DNNL_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem;
|
||||
}
|
||||
|
||||
// run calculations
|
||||
mkldnn::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);
|
||||
dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (dLdIReorder)
|
||||
mkldnn::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
|
||||
dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
@ -532,37 +532,37 @@ PLATFORM_CHECK(batchnorm) {
|
|||
// weights({1, 2, 0, 0}).assign(0.0f);
|
||||
|
||||
// dnnl_memory_desc_t empty;
|
||||
// mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
|
||||
// dnnl::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
|
||||
|
||||
// auto flag = mkldnn::normalization_flags::use_global_stats;
|
||||
// auto flag = dnnl::normalization_flags::use_global_stats;
|
||||
// if (applyScale || applyOffset)
|
||||
// flag |= mkldnn::normalization_flags::use_scale_shift;
|
||||
// flag |= dnnl::normalization_flags::use_scale_shift;
|
||||
|
||||
// mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
|
||||
// &batchnorm_src_md, nullptr, &batchnorm_dst_md,
|
||||
// &user_src_md, nullptr, &user_dst_md, axes[0]);
|
||||
|
||||
// auto batchnorm_desc = mkldnn::batch_normalization_forward::desc(mkldnn::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag);
|
||||
// auto batchnorm_desc = dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag);
|
||||
|
||||
// auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
// mkldnn::stream stream(engine);
|
||||
// auto batchnorm_prim_desc = mkldnn::batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
|
||||
// auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||
// auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||
// auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine,
|
||||
// dnnl::stream stream(engine);
|
||||
// auto batchnorm_prim_desc = dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
|
||||
// auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
// auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
// auto batchnorm_mean_memory = dnnl::memory(batchnorm_prim_desc.mean_desc(), engine,
|
||||
// mean->buffer());
|
||||
// auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine,
|
||||
// auto batchnorm_variance_memory = dnnl::memory(batchnorm_prim_desc.variance_desc(), engine,
|
||||
// variance->buffer());
|
||||
// auto batchnorm_src_memory = user_src_memory;
|
||||
// mkldnn::memory m(batchnorm_src_md, engine);
|
||||
// dnnl::memory m(batchnorm_src_md, engine);
|
||||
// if (m.get_desc() != user_src_memory.get_desc()) {
|
||||
// batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine);
|
||||
// mkldnn::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
|
||||
// batchnorm_src_memory = dnnl::memory(batchnorm_src_md, engine);
|
||||
// dnnl::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
|
||||
// batchnorm_src_memory);
|
||||
// }
|
||||
// auto batchnorm_dst_memory = user_dst_memory;
|
||||
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
// batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine);
|
||||
// batchnorm_dst_memory = dnnl::memory(batchnorm_prim_desc.dst_desc(), engine);
|
||||
// }
|
||||
// if (applyScale || applyOffset) {
|
||||
// if (gamma != nullptr) {
|
||||
|
@ -572,22 +572,22 @@ PLATFORM_CHECK(batchnorm) {
|
|||
// weights({1, 2, 0, 0}).assign(beta);
|
||||
// }
|
||||
|
||||
// auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
|
||||
// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
||||
// auto batchnorm_weights_memory = dnnl::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
|
||||
// dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
||||
// {{DNNL_ARG_SRC, batchnorm_src_memory},
|
||||
// {DNNL_ARG_MEAN, batchnorm_mean_memory},
|
||||
// {DNNL_ARG_VARIANCE, batchnorm_variance_memory},
|
||||
// {DNNL_ARG_WEIGHTS, batchnorm_weights_memory},
|
||||
// {DNNL_ARG_DST, batchnorm_dst_memory}});
|
||||
// } else {
|
||||
// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
||||
// dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
||||
// {{DNNL_ARG_SRC, batchnorm_src_memory},
|
||||
// {DNNL_ARG_MEAN, batchnorm_mean_memory},
|
||||
// {DNNL_ARG_VARIANCE, batchnorm_variance_memory},
|
||||
// {DNNL_ARG_DST, batchnorm_dst_memory}});
|
||||
// }
|
||||
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
// mkldnn::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
|
||||
// dnnl::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
|
||||
// user_dst_memory);
|
||||
// }
|
||||
// stream.wait();
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -47,12 +47,12 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
|
|||
if(isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
||||
empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||
empty);
|
||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
|
||||
bias, output,
|
||||
|
@ -74,38 +74,38 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
|
|||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r);
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
|
||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
|
||||
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
|
||||
auto user_weights_memory = dnnl::memory(user_weights_md, engine,
|
||||
const_cast<NDArray *>(weights)->buffer());
|
||||
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
auto conv_src_memory = user_src_memory;
|
||||
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
|
||||
conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
|
||||
}
|
||||
auto conv_weights_memory = user_weights_memory;
|
||||
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
|
||||
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
|
||||
conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine);
|
||||
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
|
||||
conv_weights_memory);
|
||||
}
|
||||
auto conv_dst_memory = user_dst_memory;
|
||||
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
|
||||
conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
if (bias != nullptr) {
|
||||
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine,
|
||||
auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine,
|
||||
const_cast<NDArray *>(bias)->buffer());
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
|
||||
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
|
||||
{MKLDNN_ARG_BIAS, conv_bias_memory},
|
||||
{MKLDNN_ARG_DST, conv_dst_memory}});
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
|
||||
{DNNL_ARG_WEIGHTS, conv_weights_memory},
|
||||
{DNNL_ARG_BIAS, conv_bias_memory},
|
||||
{DNNL_ARG_DST, conv_dst_memory}});
|
||||
} else {
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
|
||||
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
|
||||
{MKLDNN_ARG_DST, conv_dst_memory}});
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
|
||||
{DNNL_ARG_WEIGHTS, conv_weights_memory},
|
||||
{DNNL_ARG_DST, conv_dst_memory}});
|
||||
}
|
||||
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
|
||||
|
@ -198,12 +198,12 @@ PLATFORM_IMPL(conv2d_bp) {
|
|||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
|
||||
gradB, gradO,
|
||||
|
@ -235,47 +235,47 @@ PLATFORM_IMPL(conv2d_bp) {
|
|||
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
||||
conv_prim_desc);
|
||||
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
|
||||
auto userW_src_memory = dnnl::memory(user_src_md, engine,
|
||||
const_cast<NDArray *>(input)->buffer());
|
||||
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||
auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||
auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
|
||||
const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convW_src_memory = userW_src_memory;
|
||||
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
||||
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
|
||||
convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
|
||||
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
||||
convW_src_memory);
|
||||
}
|
||||
|
||||
auto convW_weights_memory = userW_weights_memory;
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||
convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||
}
|
||||
|
||||
auto convW_dst_memory = userW_dst_memory;
|
||||
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
||||
convW_dst_memory);
|
||||
}
|
||||
|
||||
if (gradB != nullptr) {
|
||||
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||
gradB->buffer());
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||
} else {
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||
}
|
||||
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||
|
@ -293,38 +293,38 @@ PLATFORM_IMPL(conv2d_bp) {
|
|||
conv_padding, conv_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
||||
conv_prim_desc);
|
||||
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
|
||||
const_cast<NDArray *>(weights)->buffer());
|
||||
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||
auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
|
||||
const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convI_src_memory = userI_src_memory;
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto convI_weights_memory = userI_weights_memory;
|
||||
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
|
||||
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
|
||||
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
||||
convI_weights_memory);
|
||||
}
|
||||
|
||||
auto convI_dst_memory = userI_dst_memory;
|
||||
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
||||
convI_dst_memory);
|
||||
}
|
||||
|
||||
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
|
||||
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
|
||||
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
|
||||
{{DNNL_ARG_DIFF_DST, convI_dst_memory},
|
||||
{DNNL_ARG_WEIGHTS, convI_weights_memory},
|
||||
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
|
||||
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -83,12 +83,12 @@ PLATFORM_IMPL(conv3dnew) {
|
|||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
||||
empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||
empty);
|
||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
||||
isNCDHW,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
|
||||
|
@ -110,37 +110,37 @@ PLATFORM_IMPL(conv3dnew) {
|
|||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r);
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
|
||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
|
||||
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
|
||||
auto user_weights_memory = dnnl::memory(user_weights_md, engine,
|
||||
const_cast<NDArray *>(weights)->buffer());
|
||||
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
auto conv_src_memory = user_src_memory;
|
||||
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
|
||||
conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
|
||||
}
|
||||
auto conv_weights_memory = user_weights_memory;
|
||||
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
|
||||
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
|
||||
conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine);
|
||||
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
|
||||
conv_weights_memory);
|
||||
}
|
||||
auto conv_dst_memory = user_dst_memory;
|
||||
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
|
||||
conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
if (bias != nullptr) {
|
||||
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
|
||||
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
|
||||
{MKLDNN_ARG_BIAS, conv_bias_memory},
|
||||
{MKLDNN_ARG_DST, conv_dst_memory}});
|
||||
auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
|
||||
{DNNL_ARG_WEIGHTS, conv_weights_memory},
|
||||
{DNNL_ARG_BIAS, conv_bias_memory},
|
||||
{DNNL_ARG_DST, conv_dst_memory}});
|
||||
} else {
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
|
||||
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
|
||||
{MKLDNN_ARG_DST, conv_dst_memory}});
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
|
||||
{DNNL_ARG_WEIGHTS, conv_weights_memory},
|
||||
{DNNL_ARG_DST, conv_dst_memory}});
|
||||
}
|
||||
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
|
||||
|
@ -235,12 +235,12 @@ PLATFORM_IMPL(conv3dnew_bp) {
|
|||
oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
||||
isNDHWC,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
|
||||
|
@ -273,47 +273,47 @@ PLATFORM_IMPL(conv3dnew_bp) {
|
|||
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
||||
conv_prim_desc);
|
||||
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
|
||||
auto userW_src_memory = dnnl::memory(user_src_md, engine,
|
||||
const_cast<NDArray *>(input)->buffer());
|
||||
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||
auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||
auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
|
||||
const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convW_src_memory = userW_src_memory;
|
||||
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
||||
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
|
||||
convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
|
||||
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
||||
convW_src_memory);
|
||||
}
|
||||
|
||||
auto convW_weights_memory = userW_weights_memory;
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||
convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||
}
|
||||
|
||||
auto convW_dst_memory = userW_dst_memory;
|
||||
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
||||
convW_dst_memory);
|
||||
}
|
||||
|
||||
if (gradB != nullptr) {
|
||||
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||
gradB->buffer());
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||
} else {
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||
}
|
||||
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||
|
@ -330,38 +330,38 @@ PLATFORM_IMPL(conv3dnew_bp) {
|
|||
conv_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
||||
conv_prim_desc);
|
||||
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
|
||||
const_cast<NDArray *>(weights)->buffer());
|
||||
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||
auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
|
||||
const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convI_src_memory = userI_src_memory;
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto convI_weights_memory = userI_weights_memory;
|
||||
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
|
||||
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
|
||||
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
||||
convI_weights_memory);
|
||||
}
|
||||
|
||||
auto convI_dst_memory = userI_dst_memory;
|
||||
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
||||
convI_dst_memory);
|
||||
}
|
||||
|
||||
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
|
||||
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
|
||||
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
|
||||
{{DNNL_ARG_DIFF_DST, convI_dst_memory},
|
||||
{DNNL_ARG_WEIGHTS, convI_weights_memory},
|
||||
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
|
||||
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
||||
|
|
|
@ -49,77 +49,77 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
|
||||
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
|
||||
|
||||
mkldnn::memory::dims strides = { sH, sW };
|
||||
mkldnn::memory::dims padding = { pH, pW };
|
||||
mkldnn::memory::dims padding_r = { pHmkl, pWmkl };
|
||||
mkldnn::memory::dims dilation = { dHmkl, dWmkl };
|
||||
dnnl::memory::dims strides = { sH, sW };
|
||||
dnnl::memory::dims padding = { pH, pW };
|
||||
dnnl::memory::dims padding_r = { pHmkl, pWmkl };
|
||||
dnnl::memory::dims dilation = { dHmkl, dWmkl };
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type xType;
|
||||
dnnl::memory::data_type xType;
|
||||
if(input->dataType() == DataType::FLOAT32)
|
||||
xType = mkldnn::memory::data_type::f32;
|
||||
xType = dnnl::memory::data_type::f32;
|
||||
else if(input->dataType() == DataType::HALF)
|
||||
xType = mkldnn::memory::data_type::f16;
|
||||
xType = dnnl::memory::data_type::f16;
|
||||
else if(input->dataType() == DataType::UINT8)
|
||||
xType = mkldnn::memory::data_type::u8;
|
||||
xType = dnnl::memory::data_type::u8;
|
||||
else
|
||||
xType = mkldnn::memory::data_type::s8;
|
||||
xType = dnnl::memory::data_type::s8;
|
||||
|
||||
// weights type
|
||||
mkldnn::memory::data_type wType = xType;
|
||||
if(xType == mkldnn::memory::data_type::u8)
|
||||
wType = mkldnn::memory::data_type::s8;
|
||||
dnnl::memory::data_type wType = xType;
|
||||
if(xType == dnnl::memory::data_type::u8)
|
||||
wType = dnnl::memory::data_type::s8;
|
||||
|
||||
// output and bias type (have the same types)
|
||||
mkldnn::memory::data_type zType;
|
||||
dnnl::memory::data_type zType;
|
||||
if(output->dataType() == DataType::FLOAT32)
|
||||
zType = mkldnn::memory::data_type::f32;
|
||||
zType = dnnl::memory::data_type::f32;
|
||||
else if(output->dataType() == DataType::HALF)
|
||||
zType = mkldnn::memory::data_type::f16;
|
||||
zType = dnnl::memory::data_type::f16;
|
||||
else if(output->dataType() == DataType::UINT8)
|
||||
zType = mkldnn::memory::data_type::u8;
|
||||
zType = dnnl::memory::data_type::u8;
|
||||
else if(output->dataType() == DataType::INT8)
|
||||
zType = mkldnn::memory::data_type::s8;
|
||||
zType = dnnl::memory::data_type::s8;
|
||||
else
|
||||
zType = mkldnn::memory::data_type::s32;
|
||||
zType = dnnl::memory::data_type::s32;
|
||||
|
||||
|
||||
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
|
||||
|
||||
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
|
||||
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
|
||||
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
|
||||
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||
dnnl::memory::dims wDims = {oC, iC, kH, kW};
|
||||
dnnl::memory::dims zDims = {bS, oC, oH, oW};
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// input
|
||||
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||
|
||||
// weights
|
||||
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||
|
||||
// bias
|
||||
mkldnn::memory::desc b_mkl_md;
|
||||
dnnl::memory::desc b_mkl_md;
|
||||
if(bias != nullptr)
|
||||
b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x);
|
||||
b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x);
|
||||
|
||||
// output
|
||||
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat);
|
||||
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
|
||||
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
|
||||
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
|
||||
|
@ -128,51 +128,51 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// operation primitive description
|
||||
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct,
|
||||
dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct,
|
||||
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
|
||||
dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, mkldnn::memory> args;
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
|
||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
if (xReorder)
|
||||
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||
|
||||
// weights
|
||||
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
||||
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
if (wReorder)
|
||||
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||
|
||||
// bias
|
||||
if(bias != nullptr) {
|
||||
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer());
|
||||
args[MKLDNN_ARG_BIAS] = b_mkl_mem;
|
||||
auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
|
||||
args[DNNL_ARG_BIAS] = b_mkl_mem;
|
||||
}
|
||||
|
||||
// output
|
||||
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer());
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[MKLDNN_ARG_DST] = z_mkl_mem;
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
|
||||
// run calculations
|
||||
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args);
|
||||
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
@ -196,157 +196,157 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
|
||||
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
|
||||
|
||||
mkldnn::memory::dims strides = { sH, sW };
|
||||
mkldnn::memory::dims padding = { pH, pW };
|
||||
mkldnn::memory::dims padding_r = { pHmkl, pWmkl };
|
||||
mkldnn::memory::dims dilation = { dHmkl, dWmkl };
|
||||
dnnl::memory::dims strides = { sH, sW };
|
||||
dnnl::memory::dims padding = { pH, pW };
|
||||
dnnl::memory::dims padding_r = { pHmkl, pWmkl };
|
||||
dnnl::memory::dims dilation = { dHmkl, dWmkl };
|
||||
// input type
|
||||
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// weights type
|
||||
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradO type
|
||||
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradI type
|
||||
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradW type
|
||||
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradB type
|
||||
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32;
|
||||
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
|
||||
|
||||
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
|
||||
|
||||
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
|
||||
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
|
||||
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
|
||||
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||
dnnl::memory::dims wDims = {oC, iC, kH, kW};
|
||||
dnnl::memory::dims zDims = {bS, oC, oH, oW};
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// input
|
||||
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||
|
||||
// weights
|
||||
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||
|
||||
// gradO
|
||||
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
|
||||
|
||||
// gradI
|
||||
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
|
||||
|
||||
// gradW
|
||||
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
|
||||
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
|
||||
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
|
||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
|
||||
|
||||
// gradB
|
||||
mkldnn::memory::desc gradB_mkl_md;
|
||||
dnnl::memory::desc gradB_mkl_md;
|
||||
if(gradB != nullptr)
|
||||
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x);
|
||||
gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);
|
||||
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// forward primitive description
|
||||
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
|
||||
// backward data primitive description
|
||||
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||
dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||
|
||||
// backward weights primitive description
|
||||
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
|
||||
dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, mkldnn::memory> args;
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
|
||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
||||
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
if (xReorder)
|
||||
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||
|
||||
// weights
|
||||
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
||||
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
if (wReorder)
|
||||
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||
|
||||
// gradO
|
||||
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
if (gradOReorder)
|
||||
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
|
||||
// gradW
|
||||
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer());
|
||||
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
|
||||
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||
|
||||
// gradB
|
||||
if(gradB != nullptr) {
|
||||
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer());
|
||||
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem;
|
||||
auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
|
||||
args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
|
||||
}
|
||||
|
||||
// run backward data calculations
|
||||
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// run backward weights calculations
|
||||
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
if (gradWReorder)
|
||||
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
|
|
@ -39,52 +39,52 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
|||
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
|
||||
// gradO [bS, oH, oW, oC]
|
||||
|
||||
mkldnn::memory::dims strides = { sH, sW };
|
||||
mkldnn::memory::dims dilation = { dH - 1, dW - 1 };
|
||||
mkldnn::memory::dims padding = { pH, pW };
|
||||
mkldnn::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
|
||||
dnnl::memory::dims strides = { sH, sW };
|
||||
dnnl::memory::dims dilation = { dH - 1, dW - 1 };
|
||||
dnnl::memory::dims padding = { pH, pW };
|
||||
dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
|
||||
|
||||
// weights type
|
||||
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradO type
|
||||
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradI type
|
||||
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
|
||||
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
|
||||
|
||||
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
|
||||
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
|
||||
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
|
||||
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||
dnnl::memory::dims wDims = {oC, iC, kH, kW};
|
||||
dnnl::memory::dims zDims = {bS, oC, oH, oW};
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// input
|
||||
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, gradOType, mkldnn::memory::format_tag::any);
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, gradOType, dnnl::memory::format_tag::any);
|
||||
|
||||
// weights
|
||||
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||
|
||||
// gradO
|
||||
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
|
||||
|
||||
// gradI
|
||||
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
|
||||
|
@ -94,48 +94,48 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
|||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// forward primitive description
|
||||
mkldnn::convolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
|
||||
// backward data primitive description
|
||||
mkldnn::convolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||
dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, mkldnn::memory> args;
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// weights
|
||||
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
||||
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
if (wReorder)
|
||||
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||
|
||||
// gradO
|
||||
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
if (gradOReorder)
|
||||
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
|
||||
// run backward data calculations
|
||||
mkldnn::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
|
|
@ -50,54 +50,54 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
|
||||
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
|
||||
|
||||
mkldnn::memory::dims strides = { sD, sH, sW };
|
||||
mkldnn::memory::dims padding = { pD, pH, pW };
|
||||
mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
|
||||
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
|
||||
dnnl::memory::dims strides = { sD, sH, sW };
|
||||
dnnl::memory::dims padding = { pD, pH, pW };
|
||||
dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
|
||||
dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type xType;
|
||||
dnnl::memory::data_type xType;
|
||||
if(input->dataType() == DataType::FLOAT32)
|
||||
xType = mkldnn::memory::data_type::f32;
|
||||
xType = dnnl::memory::data_type::f32;
|
||||
else if(input->dataType() == DataType::HALF)
|
||||
xType = mkldnn::memory::data_type::f16;
|
||||
xType = dnnl::memory::data_type::f16;
|
||||
else if(input->dataType() == DataType::UINT8)
|
||||
xType = mkldnn::memory::data_type::u8;
|
||||
xType = dnnl::memory::data_type::u8;
|
||||
else
|
||||
xType = mkldnn::memory::data_type::s8;
|
||||
xType = dnnl::memory::data_type::s8;
|
||||
|
||||
// weights type
|
||||
mkldnn::memory::data_type wType = xType;
|
||||
if(xType == mkldnn::memory::data_type::u8)
|
||||
wType = mkldnn::memory::data_type::s8;
|
||||
dnnl::memory::data_type wType = xType;
|
||||
if(xType == dnnl::memory::data_type::u8)
|
||||
wType = dnnl::memory::data_type::s8;
|
||||
|
||||
// output and bias type (have the same types)
|
||||
mkldnn::memory::data_type zType;
|
||||
dnnl::memory::data_type zType;
|
||||
if(output->dataType() == DataType::FLOAT32)
|
||||
zType = mkldnn::memory::data_type::f32;
|
||||
zType = dnnl::memory::data_type::f32;
|
||||
else if(output->dataType() == DataType::HALF)
|
||||
zType = mkldnn::memory::data_type::f16;
|
||||
zType = dnnl::memory::data_type::f16;
|
||||
else if(output->dataType() == DataType::UINT8)
|
||||
zType = mkldnn::memory::data_type::u8;
|
||||
zType = dnnl::memory::data_type::u8;
|
||||
else if(output->dataType() == DataType::INT8)
|
||||
zType = mkldnn::memory::data_type::s8;
|
||||
zType = dnnl::memory::data_type::s8;
|
||||
else
|
||||
zType = mkldnn::memory::data_type::s32;
|
||||
zType = dnnl::memory::data_type::s32;
|
||||
|
||||
|
||||
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw;
|
||||
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw;
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
|
||||
|
||||
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW};
|
||||
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW};
|
||||
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW};
|
||||
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
|
||||
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
|
||||
dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// input
|
||||
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||
|
@ -105,9 +105,9 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
|
||||
|
||||
// weights
|
||||
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
|
@ -115,14 +115,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
|
||||
|
||||
// bias
|
||||
mkldnn::memory::desc b_mkl_md;
|
||||
dnnl::memory::desc b_mkl_md;
|
||||
if(bias != nullptr)
|
||||
b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x);
|
||||
b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x);
|
||||
|
||||
// output
|
||||
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat);
|
||||
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
|
||||
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
|
||||
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
|
||||
|
@ -132,51 +132,51 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// operation primitive description
|
||||
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct,
|
||||
dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct,
|
||||
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
|
||||
dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, mkldnn::memory> args;
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
|
||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
if (xReorder)
|
||||
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||
|
||||
// weights
|
||||
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
||||
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
if (wReorder)
|
||||
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||
|
||||
// bias
|
||||
if(bias != nullptr) {
|
||||
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer());
|
||||
args[MKLDNN_ARG_BIAS] = b_mkl_mem;
|
||||
auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
|
||||
args[DNNL_ARG_BIAS] = b_mkl_mem;
|
||||
}
|
||||
|
||||
// output
|
||||
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer());
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[MKLDNN_ARG_DST] = z_mkl_mem;
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
|
||||
// run calculations
|
||||
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args);
|
||||
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
@ -200,37 +200,37 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
|
||||
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
|
||||
|
||||
mkldnn::memory::dims strides = { sD, sH, sW };
|
||||
mkldnn::memory::dims padding = { pD, pH, pW };
|
||||
mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
|
||||
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
|
||||
dnnl::memory::dims strides = { sD, sH, sW };
|
||||
dnnl::memory::dims padding = { pD, pH, pW };
|
||||
dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
|
||||
dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// weights type
|
||||
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradO type
|
||||
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradI type
|
||||
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradW type
|
||||
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// gradB type
|
||||
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32;
|
||||
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
|
||||
|
||||
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw; // isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
|
||||
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw;
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw; // isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
|
||||
|
||||
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW};
|
||||
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW};
|
||||
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW};
|
||||
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
|
||||
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
|
||||
dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// input
|
||||
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||
|
@ -238,9 +238,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
|
||||
|
||||
// weights
|
||||
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
|
@ -248,9 +248,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
|
||||
|
||||
// gradO
|
||||
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
|
||||
|
@ -258,9 +258,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4];
|
||||
|
||||
// gradI
|
||||
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
|
||||
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
|
||||
|
@ -268,9 +268,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4];
|
||||
|
||||
// gradW
|
||||
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
|
||||
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
|
||||
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat);
|
||||
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
|
||||
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
|
||||
|
@ -278,85 +278,85 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4];
|
||||
|
||||
// gradB
|
||||
mkldnn::memory::desc gradB_mkl_md;
|
||||
dnnl::memory::desc gradB_mkl_md;
|
||||
if(gradB != nullptr)
|
||||
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x);
|
||||
gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);
|
||||
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// forward primitive description
|
||||
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
|
||||
// backward data primitive description
|
||||
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||
dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||
|
||||
// backward weights primitive description
|
||||
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
|
||||
dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||
dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, mkldnn::memory> args;
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
|
||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
||||
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
if (xReorder)
|
||||
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||
|
||||
// weights
|
||||
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
||||
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||
if (wReorder)
|
||||
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||
|
||||
// gradO
|
||||
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
if (gradOReorder)
|
||||
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
|
||||
// gradW
|
||||
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer());
|
||||
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
|
||||
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||
|
||||
// gradB
|
||||
if(gradB != nullptr) {
|
||||
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer());
|
||||
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem;
|
||||
auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
|
||||
args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
|
||||
}
|
||||
|
||||
// run backward data calculations
|
||||
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// run backward weights calculations
|
||||
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
if (gradWReorder)
|
||||
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -44,8 +44,8 @@ namespace nd4j {
|
|||
double bias = T_ARG(0);
|
||||
int depth = INT_ARG(0);
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md,
|
||||
&user_src_md, nullptr, &user_dst_md, input->rankOf() - 1);
|
||||
|
@ -54,24 +54,24 @@ namespace nd4j {
|
|||
lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine);
|
||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
|
||||
auto lrn_src_memory = user_src_memory;
|
||||
if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
lrn_src_memory = mkldnn::memory(lrn_prim_desc.src_desc(), engine);
|
||||
lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory);
|
||||
}
|
||||
|
||||
auto lrn_dst_memory = user_dst_memory;
|
||||
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
lrn_dst_memory = mkldnn::memory(lrn_prim_desc.dst_desc(), engine);
|
||||
lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
|
||||
lrn_forward(lrn_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, lrn_src_memory},
|
||||
{MKLDNN_ARG_DST, lrn_dst_memory}});
|
||||
lrn_forward(lrn_prim_desc).execute(stream, {{DNNL_ARG_SRC, lrn_src_memory},
|
||||
{DNNL_ARG_DST, lrn_dst_memory}});
|
||||
|
||||
if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory);
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <ops/declarable/OpRegistrator.h>
|
||||
#include "mkldnnUtils.h"
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -132,52 +132,52 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
|
||||
dnnl::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
|
||||
x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md;
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type xType;
|
||||
dnnl::memory::data_type xType;
|
||||
if(x->dataType() == DataType::FLOAT32)
|
||||
xType = mkldnn::memory::data_type::f32;
|
||||
xType = dnnl::memory::data_type::f32;
|
||||
else if(x->dataType() == DataType::HALF)
|
||||
xType = mkldnn::memory::data_type::f16;
|
||||
xType = dnnl::memory::data_type::f16;
|
||||
else
|
||||
xType = mkldnn::memory::data_type::u8;
|
||||
xType = dnnl::memory::data_type::u8;
|
||||
|
||||
// weights type
|
||||
mkldnn::memory::data_type wType = xType;
|
||||
if(xType == mkldnn::memory::data_type::u8)
|
||||
wType = mkldnn::memory::data_type::s8;
|
||||
dnnl::memory::data_type wType = xType;
|
||||
if(xType == dnnl::memory::data_type::u8)
|
||||
wType = dnnl::memory::data_type::s8;
|
||||
|
||||
// bias type
|
||||
mkldnn::memory::data_type bType = xType;
|
||||
if(xType == mkldnn::memory::data_type::u8)
|
||||
bType = mkldnn::memory::data_type::f32;
|
||||
dnnl::memory::data_type bType = xType;
|
||||
if(xType == dnnl::memory::data_type::u8)
|
||||
bType = dnnl::memory::data_type::f32;
|
||||
|
||||
// output type
|
||||
mkldnn::memory::data_type hType;
|
||||
dnnl::memory::data_type hType;
|
||||
if(h->dataType() == DataType::FLOAT32)
|
||||
hType = mkldnn::memory::data_type::f32;
|
||||
hType = dnnl::memory::data_type::f32;
|
||||
else if(h->dataType() == DataType::HALF)
|
||||
hType = mkldnn::memory::data_type::f16;
|
||||
hType = dnnl::memory::data_type::f16;
|
||||
else
|
||||
hType = mkldnn::memory::data_type::u8;
|
||||
hType = dnnl::memory::data_type::u8;
|
||||
|
||||
|
||||
// memory descriptors for arrays
|
||||
// x
|
||||
x_lstm_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::any);
|
||||
// x_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, nIn}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, nIn}, type, mkldnn::memory::format_tag::ntc);
|
||||
x_user_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::tnc);
|
||||
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any);
|
||||
// x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc);
|
||||
x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
|
||||
|
||||
// wx
|
||||
wx_lstm_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::any);
|
||||
wx_user_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
|
||||
wx_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any);
|
||||
wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
|
||||
wx_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0];
|
||||
wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1];
|
||||
wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2];
|
||||
|
@ -185,9 +185,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4];
|
||||
|
||||
// wr
|
||||
wr_lstm_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::any);
|
||||
wr_user_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
|
||||
wr_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any);
|
||||
wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
|
||||
wr_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0];
|
||||
wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1];
|
||||
wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2];
|
||||
|
@ -195,19 +195,19 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4];
|
||||
|
||||
// h
|
||||
h_lstm_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::any);
|
||||
// h_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, hDirDim*nOut}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, hDirDim*nOut}, type, mkldnn::memory::format_tag::ntc);
|
||||
h_user_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::tnc);
|
||||
h_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any);
|
||||
// h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc);
|
||||
h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc);
|
||||
h_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0];
|
||||
h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1];
|
||||
h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2];
|
||||
|
||||
// b
|
||||
if(b) {
|
||||
b_lstm_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::any);
|
||||
b_user_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::ldgo);
|
||||
b_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any);
|
||||
b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo);
|
||||
b_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0];
|
||||
b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1];
|
||||
b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2];
|
||||
|
@ -216,9 +216,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
|
||||
// hI
|
||||
if(hI) {
|
||||
hI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
|
||||
hI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
|
||||
hI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
|
||||
hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
|
||||
hI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0];
|
||||
hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1];
|
||||
hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2];
|
||||
|
@ -227,9 +227,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
|
||||
// cI
|
||||
if(cI) {
|
||||
cI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
|
||||
cI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
|
||||
cI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
|
||||
cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
|
||||
cI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0];
|
||||
cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1];
|
||||
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2];
|
||||
|
@ -238,9 +238,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
|
||||
// hL
|
||||
if(hL) {
|
||||
hL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::any);
|
||||
hL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
|
||||
hL_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any);
|
||||
hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
||||
hL_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0];
|
||||
hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1];
|
||||
hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2];
|
||||
|
@ -248,9 +248,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
}
|
||||
|
||||
if(cL) {
|
||||
cL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
|
||||
cL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
|
||||
cL_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
||||
cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
||||
cL_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0];
|
||||
cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1];
|
||||
cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2];
|
||||
|
@ -262,92 +262,92 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md,
|
||||
h_lstm_md, hL_lstm_md, cL_lstm_md);
|
||||
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// lstm primitive description
|
||||
lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine);
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, mkldnn::memory> args;
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
// provide memory and check whether reorder is required
|
||||
// x
|
||||
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
|
||||
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
|
||||
const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc();
|
||||
auto x_lstm_mem = xReorder ? mkldnn::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
|
||||
auto x_lstm_mem = xReorder ? dnnl::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
|
||||
if (xReorder)
|
||||
reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem);
|
||||
args[MKLDNN_ARG_SRC_LAYER] = x_lstm_mem;
|
||||
args[DNNL_ARG_SRC_LAYER] = x_lstm_mem;
|
||||
|
||||
// wx
|
||||
auto wx_user_mem = mkldnn::memory(wx_user_md, engine, Wx->getBuffer());
|
||||
auto wx_user_mem = dnnl::memory(wx_user_md, engine, Wx->getBuffer());
|
||||
const bool wxReorder = lstm_prim_desc.weights_layer_desc()!= wx_user_mem.get_desc();
|
||||
auto wx_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
|
||||
auto wx_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
|
||||
if (wxReorder)
|
||||
reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem);
|
||||
args[MKLDNN_ARG_WEIGHTS_LAYER] = wx_lstm_mem;
|
||||
args[DNNL_ARG_WEIGHTS_LAYER] = wx_lstm_mem;
|
||||
|
||||
// wr
|
||||
auto wr_user_mem = mkldnn::memory(wr_user_md, engine, Wr->getBuffer());
|
||||
auto wr_user_mem = dnnl::memory(wr_user_md, engine, Wr->getBuffer());
|
||||
const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc();
|
||||
auto wr_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
|
||||
auto wr_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
|
||||
if (wrReorder)
|
||||
reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem);
|
||||
args[MKLDNN_ARG_WEIGHTS_ITER] = wr_lstm_mem;
|
||||
args[DNNL_ARG_WEIGHTS_ITER] = wr_lstm_mem;
|
||||
|
||||
// h
|
||||
auto h_user_mem = mkldnn::memory(h_user_md, engine, h->getBuffer());
|
||||
auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer());
|
||||
const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
|
||||
auto h_lstm_mem = hReorder ? mkldnn::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
|
||||
args[MKLDNN_ARG_DST_LAYER] = h_lstm_mem;
|
||||
auto h_lstm_mem = hReorder ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
|
||||
args[DNNL_ARG_DST_LAYER] = h_lstm_mem;
|
||||
|
||||
// b
|
||||
if(b) {
|
||||
auto b_user_mem = mkldnn::memory(b_user_md, engine, b->getBuffer());
|
||||
auto b_user_mem = dnnl::memory(b_user_md, engine, b->getBuffer());
|
||||
const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc();
|
||||
auto b_lstm_mem = bReorder ? mkldnn::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
|
||||
auto b_lstm_mem = bReorder ? dnnl::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
|
||||
if (bReorder)
|
||||
reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem);
|
||||
args[MKLDNN_ARG_BIAS] = b_lstm_mem;
|
||||
args[DNNL_ARG_BIAS] = b_lstm_mem;
|
||||
}
|
||||
|
||||
// hI
|
||||
if(hI) {
|
||||
auto hI_user_mem = mkldnn::memory(hI_user_md, engine, hI->getBuffer());
|
||||
auto hI_user_mem = dnnl::memory(hI_user_md, engine, hI->getBuffer());
|
||||
const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc();
|
||||
auto hI_lstm_mem = hIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
|
||||
auto hI_lstm_mem = hIReorder ? dnnl::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
|
||||
if (hIReorder)
|
||||
reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem);
|
||||
args[MKLDNN_ARG_SRC_ITER] = hI_lstm_mem;
|
||||
args[DNNL_ARG_SRC_ITER] = hI_lstm_mem;
|
||||
}
|
||||
|
||||
// cI
|
||||
if(cI) {
|
||||
auto cI_user_mem = mkldnn::memory(cI_user_md, engine, cI->getBuffer());
|
||||
auto cI_user_mem = dnnl::memory(cI_user_md, engine, cI->getBuffer());
|
||||
const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc();
|
||||
auto cI_lstm_mem = cIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
|
||||
auto cI_lstm_mem = cIReorder ? dnnl::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
|
||||
if (cIReorder)
|
||||
reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem);
|
||||
args[MKLDNN_ARG_SRC_ITER_C] = cI_lstm_mem;
|
||||
args[DNNL_ARG_SRC_ITER_C] = cI_lstm_mem;
|
||||
}
|
||||
|
||||
bool hLReorder(false), cLReorder(false);
|
||||
mkldnn::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
|
||||
dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
|
||||
|
||||
// hL
|
||||
if(hL) {
|
||||
hL_user_mem = mkldnn::memory(hL_user_md, engine, hL->getBuffer());
|
||||
hL_user_mem = dnnl::memory(hL_user_md, engine, hL->getBuffer());
|
||||
hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc();
|
||||
hL_lstm_mem = hLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
|
||||
args[MKLDNN_ARG_DST_ITER] = hL_lstm_mem;
|
||||
hL_lstm_mem = hLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
|
||||
args[DNNL_ARG_DST_ITER] = hL_lstm_mem;
|
||||
}
|
||||
|
||||
// cL
|
||||
if(cL) {
|
||||
cL_user_mem = mkldnn::memory(cL_user_md, engine, cL->getBuffer());
|
||||
cL_user_mem = dnnl::memory(cL_user_md, engine, cL->getBuffer());
|
||||
cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc();
|
||||
cL_lstm_mem = cLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
|
||||
args[MKLDNN_ARG_DST_ITER_C] = cL_lstm_mem;
|
||||
cL_lstm_mem = cLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
|
||||
args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem;
|
||||
}
|
||||
|
||||
// run calculations
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -82,11 +82,11 @@ namespace nd4j {
|
|||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
int extraParam0 = 1;
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
mkldnn::algorithm algorithm;
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
|
@ -102,23 +102,23 @@ namespace nd4j {
|
|||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||
{MKLDNN_ARG_DST, pool_dst_memory}});
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -89,11 +89,11 @@ namespace nd4j {
|
|||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
mkldnn::algorithm algorithm;
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
|
@ -109,44 +109,44 @@ namespace nd4j {
|
|||
pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
|
||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
|
||||
// Wraps the gradI buffer (destination of the backward pass) using the user_src_md descriptor.
// NOTE(review): the 3D max-pooling backprop wraps gradI with user_diff_src_md instead; confirm that
// user_src_md is intended here (it only matches if gradI has the same layout/strides as input).
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
|
||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine);
|
||||
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||
{MKLDNN_ARG_DST, pool_dst_memory},
|
||||
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}});
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
|
||||
// probably wrong, fix that
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory},
|
||||
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -87,11 +87,11 @@ namespace nd4j {
|
|||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
auto extraParam0 = 1;
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
mkldnn::algorithm algorithm;
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||
extraParam0, true,
|
||||
|
@ -106,24 +106,24 @@ namespace nd4j {
|
|||
pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||
{MKLDNN_ARG_DST, pool_dst_memory}});
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -93,11 +93,11 @@ namespace nd4j {
|
|||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
auto extraParam0 = 1;
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
mkldnn::algorithm algorithm;
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||
extraParam0, true,
|
||||
|
@ -115,44 +115,44 @@ namespace nd4j {
|
|||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
mkldnn::stream stream(engine);
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
|
||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer());
|
||||
auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
|
||||
|
||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine);
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine);
|
||||
auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine);
|
||||
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory},
|
||||
{MKLDNN_ARG_DST, pool_dst_memory},
|
||||
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory}});
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{MKLDNN_ARG_WORKSPACE, pool_workspace_memory},
|
||||
{MKLDNN_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
|
||||
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
|
|
|
@ -18,23 +18,23 @@
|
|||
// @author saudet
|
||||
//
|
||||
|
||||
#include <mkldnn_types.h>
|
||||
#include <dnnl_types.h>
|
||||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace mkldnnUtils {
|
||||
void getMKLDNNMemoryDescPool2d(
|
||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
||||
mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW };
|
||||
mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW };
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
|
||||
dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW };
|
||||
dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW };
|
||||
|
||||
pool_strides = { sH, sW };
|
||||
pool_kernel = { kH, kW };
|
||||
|
@ -45,14 +45,14 @@ namespace nd4j {
|
|||
algorithm = poolingMode == 0 ? algorithm::pooling_max
|
||||
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
|
||||
: algorithm::pooling_avg_include_padding;
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||
|
@ -60,9 +60,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||
|
@ -70,9 +70,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||
|
@ -84,12 +84,12 @@ namespace nd4j {
|
|||
void getMKLDNNMemoryDescPool3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
||||
mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
|
||||
mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
|
||||
dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
|
||||
dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
|
||||
|
||||
pool_strides = { sD, sH, sW };
|
||||
pool_kernel = { kD, kH, kW };
|
||||
|
@ -101,14 +101,14 @@ namespace nd4j {
|
|||
algorithm = poolingMode == 0 ? algorithm::pooling_max
|
||||
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
|
||||
: algorithm::pooling_avg_include_padding;
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
|
||||
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nCdhw8c; // doesn't work with "any"
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
|
@ -117,9 +117,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
|
@ -128,9 +128,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||
|
@ -145,15 +145,15 @@ namespace nd4j {
|
|||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) {
|
||||
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
|
||||
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
|
||||
mkldnn::memory::dims conv_bias_tz = { oC };
|
||||
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
|
||||
dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW };
|
||||
dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW };
|
||||
dnnl::memory::dims conv_bias_tz = { oC };
|
||||
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
||||
|
||||
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
|
||||
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
|
||||
|
@ -169,14 +169,14 @@ namespace nd4j {
|
|||
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||
auto formatw = mkldnn::memory::format_tag::hwio;
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
auto formatw = dnnl::memory::format_tag::hwio;
|
||||
|
||||
if (src != nullptr && conv_src_md != nullptr) {
|
||||
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||
|
@ -184,9 +184,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
||||
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||
|
@ -194,9 +194,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (weights != nullptr && conv_weights_md != nullptr) {
|
||||
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio"
|
||||
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
|
||||
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
|
||||
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
|
||||
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
|
||||
|
@ -204,9 +204,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
||||
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio"
|
||||
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
|
||||
|
@ -214,14 +214,14 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (bias != nullptr && conv_bias_md != nullptr) {
|
||||
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x);
|
||||
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
|
||||
}
|
||||
|
||||
if (dst != nullptr && conv_dst_md != nullptr) {
|
||||
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||
|
@ -233,15 +233,15 @@ namespace nd4j {
|
|||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) {
|
||||
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
|
||||
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
|
||||
mkldnn::memory::dims conv_bias_tz = { oC };
|
||||
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
|
||||
dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
|
||||
dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
|
||||
dnnl::memory::dims conv_bias_tz = { oC };
|
||||
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
||||
|
||||
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
|
||||
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
|
||||
|
@ -251,14 +251,14 @@ namespace nd4j {
|
|||
conv_padding_r = { pDmkl, pHmkl, pWmkl };
|
||||
conv_dilation = { dDmkl, dHmkl, dWmkl };
|
||||
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
|
||||
auto formatw = mkldnn::memory::format_tag::dhwio;
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
auto formatw = dnnl::memory::format_tag::dhwio;
|
||||
|
||||
if (src != nullptr && conv_src_md != nullptr) {
|
||||
*conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
|
@ -267,9 +267,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
||||
*conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
|
@ -278,9 +278,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (weights != nullptr && conv_weights_md != nullptr) {
|
||||
*conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio"
|
||||
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
|
||||
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
|
||||
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
|
||||
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
|
||||
|
@ -289,9 +289,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
||||
*conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio"
|
||||
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
|
||||
|
@ -300,14 +300,14 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (bias != nullptr && conv_bias_md != nullptr) {
|
||||
*conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x);
|
||||
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
|
||||
}
|
||||
|
||||
if (dst != nullptr && conv_dst_md != nullptr) {
|
||||
*conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any);
|
||||
*user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||
|
@ -318,23 +318,23 @@ namespace nd4j {
|
|||
|
||||
|
||||
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
// mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
|
||||
// mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
|
||||
// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
|
||||
// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
|
||||
// const Nd4jLong* shape = src->getShapeInfo();
|
||||
// Nd4jLong rank = shape[0];
|
||||
// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
||||
// Nd4jLong dim2 = axis >= 2 ? 1 : 2;
|
||||
// Nd4jLong dim3 = axis >= 3 ? 2 : 3;
|
||||
// mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||
// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||
|
||||
// auto type = mkldnn::memory::data_type::f32;
|
||||
// auto format = mkldnn::memory::format_tag::nchw;
|
||||
// auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||
// auto type = dnnl::memory::data_type::f32;
|
||||
// auto format = dnnl::memory::format_tag::nchw;
|
||||
// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||
|
||||
// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
|
||||
// *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_src_md->data.format_kind = mkldnn_blocked; // overrides format
|
||||
// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_src_md->data.format_kind = dnnl_blocked; // overrides format
|
||||
// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
||||
// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
||||
// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
||||
|
@ -342,9 +342,9 @@ namespace nd4j {
|
|||
// }
|
||||
|
||||
// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
|
||||
// *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
|
||||
// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
||||
|
@ -352,9 +352,9 @@ namespace nd4j {
|
|||
// }
|
||||
|
||||
// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
|
||||
// *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
|
||||
// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_dst_md->data.format_kind = dnnl_blocked; // overrides format
|
||||
// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
||||
// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
||||
// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
||||
|
@ -364,23 +364,23 @@ namespace nd4j {
|
|||
|
||||
|
||||
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
|
||||
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
|
||||
const Nd4jLong* shape = src->getShapeInfo();
|
||||
long rank = shape[0];
|
||||
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
||||
long dim2 = axis >= 2 ? 1 : 2;
|
||||
long dim3 = axis >= 3 ? 2 : 3;
|
||||
mkldnn::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||
dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = axis == 1 ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
auto supposed_to_be_any_format = format; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
|
||||
*lrn_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = mkldnn_blocked;
|
||||
*lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked;
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
||||
|
@ -388,9 +388,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
|
||||
*lrn_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = mkldnn_blocked;
|
||||
*lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked;
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
||||
|
@ -398,9 +398,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
|
||||
*lrn_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_dst_md->data.format_kind = mkldnn_blocked;
|
||||
*lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked;
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
||||
|
@ -408,8 +408,8 @@ namespace nd4j {
|
|||
}
|
||||
}
|
||||
|
||||
mkldnn::engine& getEngine(void *ptr) {
|
||||
auto eng = reinterpret_cast<mkldnn::engine*>(ptr);
|
||||
dnnl::engine& getEngine(void *ptr) {
|
||||
auto eng = reinterpret_cast<dnnl::engine*>(ptr);
|
||||
return *eng;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
#include <NativeOps.h>
|
||||
#include <NDArray.h>
|
||||
#include <mkldnn.hpp>
|
||||
#include <dnnl.hpp>
|
||||
#include <MKLDNNStream.h>
|
||||
#include <graph/Context.h>
|
||||
#include <ops/declarable/PlatformHelper.h>
|
||||
|
@ -89,47 +89,47 @@ namespace nd4j{
|
|||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation);
|
||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
|
||||
|
||||
void getMKLDNNMemoryDescConv3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation);
|
||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
|
||||
|
||||
void getMKLDNNMemoryDescPool2d(
|
||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
|
||||
|
||||
void getMKLDNNMemoryDescPool3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
|
||||
|
||||
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis);
|
||||
dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
|
||||
|
||||
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis);
|
||||
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
|
||||
|
||||
mkldnn::engine& getEngine(void *ptr);
|
||||
dnnl::engine& getEngine(void *ptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ if ("${BUILD_MKLDNN}")
|
|||
set(HAVE_MKLDNN 1)
|
||||
add_definitions("-DHAVE_MKLDNN")
|
||||
include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_SOURCE_DIR}/external/mklml_lnx_2019.0.3.20190220/include ${mkldnn_SOURCE_DIR})
|
||||
set(MKLDNN mkldnn)
|
||||
set(MKLDNN dnnl)
|
||||
endif()
|
||||
|
||||
# Download and unpack flatbuffers at configure time
|
||||
|
|
Loading…
Reference in New Issue