Softmax BP mkldnn implementation (#301)

* libnd4j mkldnn softmax_bp operation implementation and integration, 2 tests added, need some refactoring and code clean up and more testing with different input shapes

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j softmax_bp update, code refactoring, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j merge master, fixed typos, minor tweaks, code clean up

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j  integrate mkldnnUtils helpers in other mkldnn operations

Signed-off-by: Oleg <oleg.semeniv@gmail.com>
master
Oleh 2020-03-12 17:25:29 +02:00 committed by GitHub
parent 88f39fad67
commit 41bde8f885
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 387 additions and 567 deletions

View File

@ -90,32 +90,13 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
// x // x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
if(x->ews() != 1 || x->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
if(xRank > 2) {
x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
}
if(xRank > 4)
x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
}
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
// z, output // z, output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
if(z->ews() != 1 || z->ordering() != 'c') {
z_user_md.data.format_kind = dnnl_blocked; // overrides format mkldnnUtils::setBlockStrides(z, xRank, z_user_md);
z_user_md.data.format_desc.blocking.strides[0] = z->strideAt(0);
z_user_md.data.format_desc.blocking.strides[1] = z->strideAt(1);
if(xRank > 2) {
z_user_md.data.format_desc.blocking.strides[2] = z->strideAt(2);
z_user_md.data.format_desc.blocking.strides[3] = z->strideAt(3);
}
if(xRank > 4)
z_user_md.data.format_desc.blocking.strides[4] = z->strideAt(4);
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -131,14 +112,9 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
// provide memory and check whether reorder is required // provide memory and check whether reorder is required
// x // x
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; // z
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
@ -230,47 +206,20 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// x // x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
if(x->ews() != 1 || x->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
if(xRank > 2) {
x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
}
if(xRank > 4)
x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
}
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
// dLdO // dLdO
dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
if(dLdO->ews() != 1 || dLdO->ordering() != 'c') {
dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format mkldnnUtils::setBlockStrides(dLdO, xRank, dLdO_user_md);
dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->strideAt(0);
dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->strideAt(1);
if(xRank > 2) {
dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->strideAt(2);
dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->strideAt(3);
}
if(xRank > 4)
dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->strideAt(4);
}
// dLdI // dLdI
dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
if(dLdI->ews() != 1 || dLdI->ordering() != 'c') {
dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format mkldnnUtils::setBlockStrides(dLdI, xRank, dLdI_user_md);
dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->strideAt(0);
dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->strideAt(1);
if(xRank > 2) {
dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->strideAt(2);
dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->strideAt(3);
}
if(xRank > 4)
dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->strideAt(4);
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -290,20 +239,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// provide memory and check whether reorder is required // provide memory and check whether reorder is required
// x // x
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// dLdO // dLdO
auto dLdO_user_mem = dnnl::memory(dLdO_user_md, engine, dLdO->getBuffer()); mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, args, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc();
auto dLdO_mkl_mem = dLdOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
if (dLdOReorder)
dnnl::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
args[DNNL_ARG_DIFF_DST] = dLdO_mkl_mem;
// mean // mean
auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());

View File

@ -67,13 +67,7 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 4, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
}
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
@ -92,13 +86,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
if(output->ews() != 1 || output->ordering() != 'c') {
z_user_md.data.format_kind = dnnl_blocked; // overrides format mkldnnUtils::setBlockStrides(output, 4, z_user_md);
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -114,20 +103,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
@ -185,13 +164,7 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 4, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
}
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
@ -205,25 +178,13 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
if(gradO->ews() != 1 || gradO->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
}
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(gradI->ews() != 1 || gradI->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
}
// gradW // gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat);
@ -260,20 +221,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());

View File

@ -70,14 +70,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 5, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
}
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
@ -97,14 +90,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
if(output->ews() != 1 || output->ordering() != 'c') { mkldnnUtils::setBlockStrides(output, 5, z_user_md);
z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -120,21 +106,11 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
@ -194,14 +170,7 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 5, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
}
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
@ -216,26 +185,14 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md);
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
}
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md);
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
}
// gradW // gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
@ -274,20 +231,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());

View File

@ -88,13 +88,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 4, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
}
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@ -113,13 +107,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
if(output->ews() != 1 || output->ordering() != 'c') { mkldnnUtils::setBlockStrides(output, 4, z_user_md);
z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -136,20 +124,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
@ -216,13 +194,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 4, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
}
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@ -236,24 +208,12 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
if(gradO->ews() != 1 || gradO->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
}
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
if(gradI->ews() != 1 || gradI->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
}
// gradW // gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
@ -291,20 +251,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());

View File

@ -76,24 +76,12 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
if(gradO->ews() != 1 || gradO->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
}
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
if(gradI->ews() != 1 || gradI->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -113,20 +101,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder)
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI // gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
@ -146,8 +124,6 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// shape::printArray(z_mkl_mem.map_data<float>(),8); // shape::printArray(z_mkl_mem.map_data<float>(),8);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {

View File

@ -89,14 +89,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 5, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
}
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@ -116,14 +109,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
if(output->ews() !=1 || output->ordering() != 'c') { mkldnnUtils::setBlockStrides(output, 5, z_user_md);
z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -140,20 +126,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
@ -223,14 +199,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 5, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
}
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@ -245,26 +214,12 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
if(gradO->ews() != 1 || gradO->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md);
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
}
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
if(gradI->ews() != 1 || gradI->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md);
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
}
// gradW // gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat);
@ -304,20 +259,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());

View File

@ -98,13 +98,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 4, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); // do permutation NHWC -> NCHW
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
}
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@ -124,13 +118,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat);
if(output->ews() != 1 || output->ordering() != 'c') { mkldnnUtils::setBlockStrides(output, 4, z_user_md);
z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); // do permutation NHWC -> NCHW
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -147,20 +135,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
@ -235,13 +213,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') { mkldnnUtils::setBlockStrides(input, 4, x_user_md);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
}
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@ -256,24 +228,12 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat);
if(gradO->ews() != 1 || gradO->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
}
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat);
if(gradI->ews() != 1 || gradI->ordering() != 'c') { mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
}
// gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; // gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
@ -312,20 +272,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// weights // weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());

View File

@ -272,29 +272,14 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// provide memory and check whether reorder is required // provide memory and check whether reorder is required
// x // x
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, lstm_prim_desc.src_layer_desc(), DNNL_ARG_SRC_LAYER);
const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc();
auto x_lstm_mem = xReorder ? dnnl::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
if (xReorder)
reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem);
args[DNNL_ARG_SRC_LAYER] = x_lstm_mem;
// wx // wx
auto wx_user_mem = dnnl::memory(wx_user_md, engine, Wx->getBuffer()); mkldnnUtils::loadDataToMklStream(Wx, engine, stream, args, wx_user_md, lstm_prim_desc.weights_layer_desc(), DNNL_ARG_WEIGHTS_LAYER);
const bool wxReorder = lstm_prim_desc.weights_layer_desc()!= wx_user_mem.get_desc();
auto wx_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
if (wxReorder)
reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem);
args[DNNL_ARG_WEIGHTS_LAYER] = wx_lstm_mem;
// wr // wr
auto wr_user_mem = dnnl::memory(wr_user_md, engine, Wr->getBuffer()); mkldnnUtils::loadDataToMklStream(Wr, engine, stream, args, wr_user_md, lstm_prim_desc.weights_iter_desc(), DNNL_ARG_WEIGHTS_ITER);
const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc();
auto wr_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
if (wrReorder)
reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem);
args[DNNL_ARG_WEIGHTS_ITER] = wr_lstm_mem;
// h // h
auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer()); auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer());
const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
@ -303,32 +288,17 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// b // b
if(b) { if(b) {
auto b_user_mem = dnnl::memory(b_user_md, engine, b->getBuffer()); mkldnnUtils::loadDataToMklStream(b, engine, stream, args, b_user_md, lstm_prim_desc.bias_desc(), DNNL_ARG_BIAS);
const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc();
auto b_lstm_mem = bReorder ? dnnl::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
if (bReorder)
reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem);
args[DNNL_ARG_BIAS] = b_lstm_mem;
} }
// hI // hI
if(hI) { if(hI) {
auto hI_user_mem = dnnl::memory(hI_user_md, engine, hI->getBuffer()); mkldnnUtils::loadDataToMklStream(hI, engine, stream, args, hI_user_md, lstm_prim_desc.src_iter_desc(), DNNL_ARG_SRC_ITER);
const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc();
auto hI_lstm_mem = hIReorder ? dnnl::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
if (hIReorder)
reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem);
args[DNNL_ARG_SRC_ITER] = hI_lstm_mem;
} }
// cI // cI
if(cI) { if(cI) {
auto cI_user_mem = dnnl::memory(cI_user_md, engine, cI->getBuffer()); mkldnnUtils::loadDataToMklStream(cI, engine, stream, args, cI_user_md, lstm_prim_desc.src_iter_c_desc(), DNNL_ARG_SRC_ITER_C);
const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc();
auto cI_lstm_mem = cIReorder ? dnnl::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
if (cIReorder)
reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem);
args[DNNL_ARG_SRC_ITER_C] = cI_lstm_mem;
} }
bool hLReorder(false), cLReorder(false); bool hLReorder(false), cLReorder(false);

View File

@ -163,21 +163,25 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(xTR, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
/*
auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer()); auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder) if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem; args[DNNL_ARG_SRC] = x_mkl_mem;
*/
// y // y
mkldnnUtils::loadDataToMklStream(yTR, engine, stream, args, y_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
/*
auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer()); auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer());
const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();
auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem; auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem;
if (yReorder) if (yReorder)
dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem); dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem);
args[DNNL_ARG_WEIGHTS] = y_mkl_mem; args[DNNL_ARG_WEIGHTS] = y_mkl_mem;
*/
// z // z
auto z_user_mem = dnnl::memory(z_user_md, engine, zR->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, zR->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();

View File

@ -28,6 +28,55 @@ using namespace dnnl;
namespace sd { namespace sd {
namespace mkldnnUtils { namespace mkldnnUtils {
//////////////////////////////////////////////////////////////////////
void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){
    // Copies the first `rank` dimension sizes of `array` into the oneDNN dims object.
    // dnnl::memory::dims is a std::vector<dnnl_dim_t>, so we can resize and fill the
    // output in place instead of building a temporary std::vector and copying it over.
    mklDims.resize(rank);
    for (int i = 0; i < rank; ++i) {
        mklDims[i] = array->sizeAt(i);
    }
}
//////////////////////////////////////////////////////////////////////
dnnl::memory::format_tag getFormat(const int rank){
    // Maps a tensor rank to the corresponding plain (row-major) oneDNN format tag.
    switch (rank) {
        case 2:
            return dnnl::memory::format_tag::ab;
        case 3:
            return dnnl::memory::format_tag::abc;
        case 4:
            return dnnl::memory::format_tag::abcd;
        case 5:
            return dnnl::memory::format_tag::abcde;
        case 6:
            return dnnl::memory::format_tag::abcdef;
        default:
            return dnnl::memory::format_tag::a; // rank 1 (and any other rank) falls back to 1D
    }
}
//////////////////////////////////////////////////////////////////////
void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd){
    // Dense c-order arrays already match the plain format tag — nothing to override.
    const bool denseCOrder = (array->ews() == 1 && array->ordering() == 'c');
    if (denseCOrder)
        return;
    // Otherwise describe the array's actual layout explicitly via blocked strides.
    mklMd.data.format_kind = dnnl_blocked; // overrides format
    for (int i = 0; i < rank; ++i)
        mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i);
}
////////////////////////////////////////////////////////////////////////////////////////////////
void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream,
std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG ){
auto user_mem = dnnl::memory(user_md, engine, array->getBuffer());
const bool bReorder = primitive_md != user_mem.get_desc();
auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem;
if (bReorder)
dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem);
args[DNNL_ARG] = mkl_mem;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
void poolingMKLDNN(const NDArray *input, NDArray *output, void poolingMKLDNN(const NDArray *input, NDArray *output,
const int kD, const int kH, const int kW, const int kD, const int kH, const int kW,
@ -113,12 +162,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// output // output
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
@ -236,13 +280,8 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
std::unordered_map<int, dnnl::memory> args; std::unordered_map<int, dnnl::memory> args;
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
const bool gradOReorder = op_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder)
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI // gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
@ -252,13 +291,8 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
if(mode == algorithm::pooling_max) { if(mode == algorithm::pooling_max) {
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// z // z
auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);
args[DNNL_ARG_DST] = z_mkl_mem; args[DNNL_ARG_DST] = z_mkl_mem;

View File

@ -90,6 +90,8 @@ namespace sd {
DECLARE_PLATFORM(softmax, ENGINE_CPU); DECLARE_PLATFORM(softmax, ENGINE_CPU);
DECLARE_PLATFORM(softmax_bp, ENGINE_CPU);
DECLARE_PLATFORM(tanh, ENGINE_CPU); DECLARE_PLATFORM(tanh, ENGINE_CPU);
} }
@ -107,6 +109,41 @@ namespace sd {
dnnl::engine& getEngine(void* ptr); dnnl::engine& getEngine(void* ptr);
/**
* This function creates memory dimensions
* @param const pointer to array
* @param const array rank
* @param reference to memory dimensions
*/
void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims);
/**
* This function generates a memory format tag based on rank
* @param const array rank
* @return memory format
*/
dnnl::memory::format_tag getFormat(const int rank);
/**
* This function sets blocked strides on the memory descriptor for arrays that are not plain c-ordered buffers
* @param const pointer to dataset
* @param const dataset rank
* @param reference to memory descriptor
*/
void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd);
//////////////////////////////////////////////////////////////////////
/**
* This function loads and, if required, reorders user memory into mkl memory
* @param const pointer to dataset
* @param reference to mkl engine
* @param reference to mkl stream
* @param reference to args container for dnnl
* @param reference to user memory description
* @param primitive memory descriptor
* @param dnnl arg activation enumerator
*/
void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream,
std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG);
/** /**
* Utility methods for MKLDNN * Utility methods for MKLDNN
*/ */

View File

@ -31,69 +31,36 @@ namespace sd {
namespace ops { namespace ops {
namespace platforms { namespace platforms {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) { static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) {
const auto xRank = x->rankOf(); const auto xRank = x->rankOf();
const auto zRank = z->rankOf(); dnnl::memory::dims xShape, zShape;
std::vector<int64_t> dimsX(xRank), dimsZ(zRank); mkldnnUtils::getDims(x, xRank, xShape);
for (auto i = 0; i < xRank; i++) { mkldnnUtils::getDims(z, xRank, zShape);
dimsX[i] = x->sizeAt(i);
dimsZ[i] = z->sizeAt(i);
}
dnnl::memory::dims xShape = dnnl::memory::dims(dimsX);
dnnl::memory::dims zShape = dnnl::memory::dims(dimsZ);
dnnl::memory::format_tag format = dnnl::memory::format_tag::a; // 1 == xRank dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
if (2 == xRank && 1 == axis) { // optimized cases
format = dnnl::memory::format_tag::ab; if (2 == xRank && 0 == axis) {
}
else if (2 == xRank && 0 == axis) {
format = dnnl::memory::format_tag::ba; format = dnnl::memory::format_tag::ba;
} }
else if (3 == xRank) { else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) {
format = dnnl::memory::format_tag::abc;
}
else if (4 == xRank && 3 == axis) {
format = dnnl::memory::format_tag::abcd;
}
else if (4 == xRank && 1 == axis && dimsX[2] * dimsX[3] > 1) {
format = dnnl::memory::format_tag::acdb; format = dnnl::memory::format_tag::acdb;
} }
else if (4 == xRank) {
format = dnnl::memory::format_tag::abcd;
}
else if (5 == xRank) {
format = dnnl::memory::format_tag::abcde;
}
else if (6 == xRank) {
format = dnnl::memory::format_tag::abcdef;
}
dnnl::memory::data_type xType = dnnl::memory::data_type::f32; dnnl::memory::data_type xType = dnnl::memory::data_type::f32;
dnnl::memory::data_type zType = dnnl::memory::data_type::f32;
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
if (x->ews() != 1 || x->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
for (auto i = 0; i < xRank; ++i) {
x_user_md.data.format_desc.blocking.strides[i] = x->strideAt(i);
}
}
// z // z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format);
if (z->ews() != 1 || z->ordering() != 'c') { mkldnnUtils::setBlockStrides(z, xRank, z_user_md);
z_user_md.data.format_kind = dnnl_blocked; // overrides format
for (auto i = 0; i < xRank; ++i) {
z_user_md.data.format_desc.blocking.strides[i] = z->strideAt(i);
}
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -101,7 +68,6 @@ namespace sd {
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
// operation primitive description // operation primitive description
// todo check this
dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis); dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis);
dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine); dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine);
@ -114,12 +80,7 @@ namespace sd {
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// z // z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
@ -178,6 +139,136 @@ namespace sd {
} }
//////////////////////////////////////////////////////////////////////
// Softmax backward pass via oneDNN (MKL-DNN).
// Strategy: run softmax FORWARD on x, writing the result into dLdx's buffer,
// then run softmax BACKWARD with dLdz as the upstream gradient; the backward
// primitive reads DST (the forward softmax output) and overwrites the same
// memory with DIFF_SRC, so dLdx ends up holding dL/dx.
// All memory descriptors use f32 (PLATFORM_CHECK only admits FLOAT32 inputs).
static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) {
const auto xRank = x->rankOf();
const auto dLdzRank = dLdz->rankOf();
dnnl::memory::dims xShape, dLdxShape, dLdzShape;
mkldnnUtils::getDims(x, xRank, xShape);
mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape);
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
// x: plain descriptor plus a user descriptor that picks up actual strides
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
// dLdx: output gradient buffer (also reused as the forward DST, see above)
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md);
// todo if mkl does not support broadcast we can remove this
// dLdz may have its own rank, so recompute the format tag for it
format = mkldnnUtils::getFormat(dLdzRank);
// dLdz: upstream gradient
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdz, dLdzRank, dLdz_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// operation primitive description
// forward description (also serves as the hint for the backward primitive)
dnnl::softmax_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis);
dnnl::softmax_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward description
dnnl::softmax_backward::desc op_bp_desc(dLdz_mkl_md, dLdx_mkl_md, axis);
dnnl::softmax_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, dnnl::memory> argsbp, argsff;
dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required for forward
// input
mkldnnUtils::loadDataToMklStream(x, engine, stream, argsff, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);
// dLdx: reorder check is done against the FORWARD dst descriptor
// NOTE(review): assumes backward diff_src_desc matches forward dst_desc — confirm
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());
const bool dLdxReorder = op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc();
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : dLdx_user_mem;
argsff[DNNL_ARG_DST] = dLdx_mkl_mem;
// backward shares the same mkl memory as both DST (softmax output it reads)
// and DIFF_SRC (the gradient it writes) — in-place computation
argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
argsbp[DNNL_ARG_DST] = dLdx_mkl_mem;
// dLdz
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, argsbp, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
// run calculations forward
dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff);
// run calculations backward
dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp);
// reorder outputs if necessary (copy the mkl-layout result back to the user buffer)
if (dLdxReorder)
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
stream.wait();
}
// Platform (oneDNN) implementation of softmax_bp: validates the requested
// dimension and ranks, then delegates to softmaxBpMKLDNN.
// Inputs: 0 - forward input x, 1 - upstream gradient dLdz; output 0 - dLdx.
PLATFORM_IMPL(softmax_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto dLdz = INPUT_VARIABLE(1);
    auto dLdx = OUTPUT_VARIABLE(0);

    const int rank = input->rankOf();
    const int dLdzRank = dLdz->rankOf();

    // softmax dimension defaults to the last axis; negative values count from the end
    int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;
    if (dim < 0) {
        dim += rank;
    }

    REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);

    // fixed error-message typos: "less or qual" -> "less or equal", doubled "rank rank"
    REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "SOFTMAX_MKLDNN_BP OP: the rank of input and dLdz must be less or equal 6, but got input rank = %i and dLdz rank = %i instead !", rank, dLdzRank);

    // mkldnnSoftMax
    softmaxBpMKLDNN(input, dLdz, dLdx, dim);

    return Status::OK();
}
// Decides whether the oneDNN softmax_bp kernel may be used: requires
// mkldnn enabled, rank < 7, equal ranks and matching shapes for x/dLdz,
// non-empty inputs, and f32 for all three arrays.
PLATFORM_CHECK(softmax_bp, ENGINE_CPU) {
    auto x = INPUT_VARIABLE(0);
    auto dLdz = INPUT_VARIABLE(1);
    auto dLdx = OUTPUT_VARIABLE(0);

    const DataType xType = x->dataType();
    const DataType dLdzType = dLdz->dataType();
    const DataType dLdxType = dLdx->dataType();

    const int xRank = x->rankOf();
    const int dLdzRank = dLdz->rankOf();

    // ranks must match, stay below 7, and both inputs must be non-empty;
    // every dimension of x must agree with the corresponding dimension of dLdz
    bool shapesOk = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty());
    for (int d = 0; shapesOk && d < xRank; ++d) {
        shapesOk = x->sizeAt(d) == dLdz->sizeAt(d);
    }

    //Source Destination
    //f32 f32
    const bool typesOk = xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32;

    return block.isUseMKLDNN() && shapesOk && typesOk;
}
} }
} }
} }

View File

@ -35,52 +35,21 @@ namespace sd {
static void tanhMKLDNN(const NDArray* x, NDArray* z) { static void tanhMKLDNN(const NDArray* x, NDArray* z) {
const auto xRank = x->rankOf(); const auto xRank = x->rankOf();
dnnl::memory::dims xShape, zShape;
std::vector<int64_t> dimsX(xRank), dimsZ(xRank); mkldnnUtils::getDims(x, xRank, xShape);
for (auto i = 0; i < xRank; i++) { mkldnnUtils::getDims(z, xRank, zShape);
dimsX[i] = x->sizeAt(i);
dimsZ[i] = z->sizeAt(i);
}
dnnl::memory::dims xShape = dnnl::memory::dims(dimsX); dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
dnnl::memory::dims zShape = dnnl::memory::dims(dimsZ);
dnnl::memory::format_tag format = dnnl::memory::format_tag::a;
if (2 == xRank) {
format = dnnl::memory::format_tag::ab;
}
else if (3 == xRank) {
format = dnnl::memory::format_tag::abc;
}
else if (4 == xRank) {
format = dnnl::memory::format_tag::abcd;
}
else if (5 == xRank) {
format = dnnl::memory::format_tag::abcde;
}
else if (6 == xRank) {
format = dnnl::memory::format_tag::abcdef;
}
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
if (x->ews() != 1 || x->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
for (auto i = 0; i < xRank; ++i) {
x_user_md.data.format_desc.blocking.strides[i] = x->strideAt(i);
}
}
// z // z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
if (z->ews() != 1 || z->ordering() != 'c') { mkldnnUtils::setBlockStrides(z, xRank, z_user_md);
z_user_md.data.format_kind = dnnl_blocked; // overrides format
for (auto i = 0; i < xRank; ++i) {
z_user_md.data.format_desc.blocking.strides[i] = z->strideAt(i);
}
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -99,12 +68,7 @@ namespace sd {
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// z // z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());

View File

@ -76,3 +76,63 @@ TEST_F(DeclarableOpsTests18, test_tanh_2) {
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
} }
//////////////////////////////////////////////////////////////////////
// softmax_bp over axis 1 of a 2x2 input: the gradient must match the
// analytic softmax-Jacobian-times-epsilon reference values.
TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST) {

    NDArray x('c', { 2, 2 }, { 1,2,3,4 }, DataType::FLOAT32);
    NDArray grad('c', { 2, 2 }, { .1, .2, .3, .4 }, DataType::FLOAT32);
    NDArray result('c', { 2, 2 }, DataType::FLOAT32);
    NDArray expected('c', { 2, 2 }, { -0.019661, 0.019661, -0.019661, 0.019661 }, DataType::FLOAT32);

    const int axis = 1;

    sd::ops::softmax_bp op;
    const Nd4jStatus execStatus = op.execute({ &x, &grad }, { &result }, {}, { axis });

    ASSERT_EQ(ND4J_STATUS_OK, execStatus);
    ASSERT_TRUE(result.equalsTo(expected));
}
//////////////////////////////////////////////////////////////////////
// softmax_bp on a rank-4 c-ordered input over the last axis (passed as -1,
// normalized to 3 inside the op); input is a linspace, expected values are
// precomputed reference gradients.
TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST2) {
NDArray input('c', { 4, 5, 2, 3 }, DataType::FLOAT32);
NDArray epsilon('c', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32);
input.linspace(0.1, 0.2);
// negative axis exercises the dim-normalization path in PLATFORM_IMPL
int axis = -1;
NDArray output('c', { 4, 5, 2, 3 }, DataType::FLOAT32);
NDArray exp('c', { 4, 5, 2, 3 }, { -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253 }, DataType::FLOAT32);
sd::ops::softmax_bp op;
Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis });
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output.equalsTo(exp));
}
//////////////////////////////////////////////////////////////////////
// softmax_bp on an f-ordered (column-major) rank-4 input over axis 1 —
// exercises the non-c-ordering / strided path (setBlockStrides).
// Expected values are authored in c-order and assigned into an f-ordered
// array so the comparison respects layout.
TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) {
NDArray input('f', { 4, 5, 2, 3 }, DataType::FLOAT32);
NDArray epsilon('f', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32);
input.linspace(-5., 0.5);
int axis = 1;
NDArray output('f', { 4, 5, 2, 3 }, DataType::FLOAT32);
NDArray expC('c', { 4, 5, 2, 3 }, { -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909, -0.000000, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, -0.000149, 0.000054, 0.000095, 0.000095, -0.000149, 0.000054, -0.001760, 0.002943, -0.001183, -0.001183, -0.001760, 0.002943, 0.001909, -0.002997, 0.001088, 0.001088, 0.001909, -0.002997, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000054, 0.000095, -0.000149, -0.000149, 0.000054, 0.000095, 0.002943, -0.001183, -0.001760, -0.001760, 0.002943, -0.001183, -0.002997, 0.001088, 0.001909, 0.001909, -0.002997, 0.001088, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909 }, DataType::FLOAT32);
NDArray exp('f', { 4, 5, 2, 3 }, DataType::FLOAT32);
// copy c-ordered reference data into the f-ordered comparison array
exp.assign(expC);
sd::ops::softmax_bp op;
Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis });
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output.equalsTo(exp));
}

View File

@ -71,8 +71,10 @@ TEST_F(MklDnnTests, helpers_includer) {
sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax; sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax;
sd::ops::platforms::PLATFORM_softmax_bp_ENGINE_CPU softmax_bp;
sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh;
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &tanh }); printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh });
#endif #endif
} }