Softmax BP mkldnn implementation (#301)
* libnd4j mkldnn softmax_bp operation implementation and integration, 2 tests added, need some refactoring and code clean up and more testing with different input shapes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j softmax_bp update, code refactoring, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j merge master, fixed typos, minor tweaks, code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j integrate mkldnnUtils helpers in other mkldnn operations Signed-off-by: Oleg <oleg.semeniv@gmail.com>master
parent
88f39fad67
commit
41bde8f885
|
@ -90,32 +90,13 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
|
||||||
// x
|
// x
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
|
||||||
if(x->ews() != 1 || x->ordering() != 'c') {
|
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
|
|
||||||
if(xRank > 2) {
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
|
|
||||||
}
|
|
||||||
if(xRank > 4)
|
|
||||||
x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
|
||||||
// z, output
|
// z, output
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
|
||||||
if(z->ews() != 1 || z->ordering() != 'c') {
|
|
||||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(z, xRank, z_user_md);
|
||||||
z_user_md.data.format_desc.blocking.strides[0] = z->strideAt(0);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[1] = z->strideAt(1);
|
|
||||||
if(xRank > 2) {
|
|
||||||
z_user_md.data.format_desc.blocking.strides[2] = z->strideAt(2);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[3] = z->strideAt(3);
|
|
||||||
}
|
|
||||||
if(xRank > 4)
|
|
||||||
z_user_md.data.format_desc.blocking.strides[4] = z->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -131,12 +112,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
|
||||||
// provide memory and check whether reorder is required
|
// provide memory and check whether reorder is required
|
||||||
|
|
||||||
// x
|
// x
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
|
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// z
|
// z
|
||||||
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
|
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
|
||||||
|
@ -230,47 +206,20 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
||||||
// x
|
// x
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
|
||||||
if(x->ews() != 1 || x->ordering() != 'c') {
|
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
|
|
||||||
if(xRank > 2) {
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
|
|
||||||
}
|
|
||||||
if(xRank > 4)
|
|
||||||
x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// dLdO
|
// dLdO
|
||||||
dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
|
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
|
||||||
if(dLdO->ews() != 1 || dLdO->ordering() != 'c') {
|
|
||||||
dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(dLdO, xRank, dLdO_user_md);
|
||||||
dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->strideAt(0);
|
|
||||||
dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->strideAt(1);
|
|
||||||
if(xRank > 2) {
|
|
||||||
dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->strideAt(2);
|
|
||||||
dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->strideAt(3);
|
|
||||||
}
|
|
||||||
if(xRank > 4)
|
|
||||||
dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// dLdI
|
// dLdI
|
||||||
dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
|
dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
|
||||||
if(dLdI->ews() != 1 || dLdI->ordering() != 'c') {
|
|
||||||
dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(dLdI, xRank, dLdI_user_md);
|
||||||
dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->strideAt(0);
|
|
||||||
dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->strideAt(1);
|
|
||||||
if(xRank > 2) {
|
|
||||||
dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->strideAt(2);
|
|
||||||
dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->strideAt(3);
|
|
||||||
}
|
|
||||||
if(xRank > 4)
|
|
||||||
dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -290,20 +239,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
||||||
// provide memory and check whether reorder is required
|
// provide memory and check whether reorder is required
|
||||||
|
|
||||||
// x
|
// x
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
|
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// dLdO
|
// dLdO
|
||||||
auto dLdO_user_mem = dnnl::memory(dLdO_user_md, engine, dLdO->getBuffer());
|
mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, args, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
|
||||||
const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc();
|
|
||||||
auto dLdO_mkl_mem = dLdOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
|
|
||||||
if (dLdOReorder)
|
|
||||||
dnnl::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
|
|
||||||
args[DNNL_ARG_DIFF_DST] = dLdO_mkl_mem;
|
|
||||||
|
|
||||||
// mean
|
// mean
|
||||||
auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
|
auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
|
||||||
|
|
|
@ -67,13 +67,7 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 4, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||||
|
@ -92,13 +86,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
|
||||||
// output
|
// output
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
||||||
if(output->ews() != 1 || output->ordering() != 'c') {
|
|
||||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(output, 4, z_user_md);
|
||||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -114,20 +103,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
if(bias != nullptr) {
|
if(bias != nullptr) {
|
||||||
|
@ -185,13 +164,7 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 4, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||||
|
@ -205,24 +178,12 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
||||||
// gradO
|
// gradO
|
||||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
||||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
|
||||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
|
||||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradW
|
// gradW
|
||||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||||
|
@ -260,20 +221,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
|
|
@ -70,14 +70,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 5, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||||
|
@ -97,14 +90,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
||||||
// output
|
// output
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
||||||
if(output->ews() != 1 || output->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(output, 5, z_user_md);
|
||||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -120,20 +106,10 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
if(bias != nullptr) {
|
if(bias != nullptr) {
|
||||||
|
@ -194,14 +170,7 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 5, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||||
|
@ -216,26 +185,14 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
||||||
// gradO
|
// gradO
|
||||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
||||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
|
||||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md);
|
||||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
|
||||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md);
|
||||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradW
|
// gradW
|
||||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||||
|
@ -274,20 +231,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
|
|
@ -88,13 +88,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 4, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
@ -113,13 +107,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
// output
|
// output
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
|
||||||
if(output->ews() != 1 || output->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(output, 4, z_user_md);
|
||||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -136,20 +124,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
if(bias != nullptr) {
|
if(bias != nullptr) {
|
||||||
|
@ -216,13 +194,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 4, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
@ -236,24 +208,12 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
||||||
// gradO
|
// gradO
|
||||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
|
||||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
|
||||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradW
|
// gradW
|
||||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||||
|
@ -291,20 +251,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
|
|
@ -76,24 +76,12 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
||||||
// gradO
|
// gradO
|
||||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
|
||||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
|
||||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -113,20 +101,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
|
||||||
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
|
||||||
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
|
||||||
if (gradOReorder)
|
|
||||||
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
|
||||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
|
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||||
|
@ -146,8 +124,6 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
||||||
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
|
PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
|
||||||
|
|
||||||
|
|
|
@ -89,14 +89,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 5, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
@ -116,14 +109,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
// output
|
// output
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
|
||||||
if(output->ews() !=1 || output->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(output, 5, z_user_md);
|
||||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -140,20 +126,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
if(bias != nullptr) {
|
if(bias != nullptr) {
|
||||||
|
@ -223,14 +199,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 5, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
@ -245,26 +214,12 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// gradO
|
// gradO
|
||||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md);
|
||||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md);
|
||||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradW
|
// gradW
|
||||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat);
|
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat);
|
||||||
|
@ -304,20 +259,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
|
|
@ -98,13 +98,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 4, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); // do permutation NHWC -> NCHW
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
@ -124,13 +118,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// output
|
// output
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat);
|
||||||
if(output->ews() != 1 || output->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(output, 4, z_user_md);
|
||||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); // do permutation NHWC -> NCHW
|
|
||||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
|
|
||||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -147,20 +135,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
if(bias != nullptr) {
|
if(bias != nullptr) {
|
||||||
|
@ -235,13 +213,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
|
||||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(input, 4, x_user_md);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
|
||||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
@ -256,24 +228,12 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
||||||
// gradO
|
// gradO
|
||||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat);
|
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat);
|
||||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
|
||||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
|
||||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat);
|
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat);
|
||||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
|
||||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
|
||||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
// gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
||||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||||
|
@ -312,20 +272,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
|
||||||
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
|
||||||
if (wReorder)
|
|
||||||
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
|
|
@ -272,28 +272,13 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
||||||
|
|
||||||
// provide memory and check whether reorder is required
|
// provide memory and check whether reorder is required
|
||||||
// x
|
// x
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
|
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, lstm_prim_desc.src_layer_desc(), DNNL_ARG_SRC_LAYER);
|
||||||
const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_lstm_mem = xReorder ? dnnl::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem);
|
|
||||||
args[DNNL_ARG_SRC_LAYER] = x_lstm_mem;
|
|
||||||
|
|
||||||
// wx
|
// wx
|
||||||
auto wx_user_mem = dnnl::memory(wx_user_md, engine, Wx->getBuffer());
|
mkldnnUtils::loadDataToMklStream(Wx, engine, stream, args, wx_user_md, lstm_prim_desc.weights_layer_desc(), DNNL_ARG_WEIGHTS_LAYER);
|
||||||
const bool wxReorder = lstm_prim_desc.weights_layer_desc()!= wx_user_mem.get_desc();
|
|
||||||
auto wx_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
|
|
||||||
if (wxReorder)
|
|
||||||
reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS_LAYER] = wx_lstm_mem;
|
|
||||||
|
|
||||||
// wr
|
// wr
|
||||||
auto wr_user_mem = dnnl::memory(wr_user_md, engine, Wr->getBuffer());
|
mkldnnUtils::loadDataToMklStream(Wr, engine, stream, args, wr_user_md, lstm_prim_desc.weights_iter_desc(), DNNL_ARG_WEIGHTS_ITER);
|
||||||
const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc();
|
|
||||||
auto wr_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
|
|
||||||
if (wrReorder)
|
|
||||||
reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem);
|
|
||||||
args[DNNL_ARG_WEIGHTS_ITER] = wr_lstm_mem;
|
|
||||||
|
|
||||||
// h
|
// h
|
||||||
auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer());
|
auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer());
|
||||||
|
@ -303,32 +288,17 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
||||||
|
|
||||||
// b
|
// b
|
||||||
if(b) {
|
if(b) {
|
||||||
auto b_user_mem = dnnl::memory(b_user_md, engine, b->getBuffer());
|
mkldnnUtils::loadDataToMklStream(b, engine, stream, args, b_user_md, lstm_prim_desc.bias_desc(), DNNL_ARG_BIAS);
|
||||||
const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc();
|
|
||||||
auto b_lstm_mem = bReorder ? dnnl::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
|
|
||||||
if (bReorder)
|
|
||||||
reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem);
|
|
||||||
args[DNNL_ARG_BIAS] = b_lstm_mem;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// hI
|
// hI
|
||||||
if(hI) {
|
if(hI) {
|
||||||
auto hI_user_mem = dnnl::memory(hI_user_md, engine, hI->getBuffer());
|
mkldnnUtils::loadDataToMklStream(hI, engine, stream, args, hI_user_md, lstm_prim_desc.src_iter_desc(), DNNL_ARG_SRC_ITER);
|
||||||
const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc();
|
|
||||||
auto hI_lstm_mem = hIReorder ? dnnl::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
|
|
||||||
if (hIReorder)
|
|
||||||
reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem);
|
|
||||||
args[DNNL_ARG_SRC_ITER] = hI_lstm_mem;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cI
|
// cI
|
||||||
if(cI) {
|
if(cI) {
|
||||||
auto cI_user_mem = dnnl::memory(cI_user_md, engine, cI->getBuffer());
|
mkldnnUtils::loadDataToMklStream(cI, engine, stream, args, cI_user_md, lstm_prim_desc.src_iter_c_desc(), DNNL_ARG_SRC_ITER_C);
|
||||||
const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc();
|
|
||||||
auto cI_lstm_mem = cIReorder ? dnnl::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
|
|
||||||
if (cIReorder)
|
|
||||||
reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem);
|
|
||||||
args[DNNL_ARG_SRC_ITER_C] = cI_lstm_mem;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hLReorder(false), cLReorder(false);
|
bool hLReorder(false), cLReorder(false);
|
||||||
|
|
|
@ -163,21 +163,25 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
|
mkldnnUtils::loadDataToMklStream(xTR, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
|
/*
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer());
|
auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer());
|
||||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
if (xReorder)
|
if (xReorder)
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||||
|
*/
|
||||||
// y
|
// y
|
||||||
|
mkldnnUtils::loadDataToMklStream(yTR, engine, stream, args, y_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||||
|
/*
|
||||||
auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer());
|
auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer());
|
||||||
const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();
|
const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();
|
||||||
auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem;
|
auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem;
|
||||||
if (yReorder)
|
if (yReorder)
|
||||||
dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem);
|
dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem);
|
||||||
args[DNNL_ARG_WEIGHTS] = y_mkl_mem;
|
args[DNNL_ARG_WEIGHTS] = y_mkl_mem;
|
||||||
|
*/
|
||||||
// z
|
// z
|
||||||
auto z_user_mem = dnnl::memory(z_user_md, engine, zR->getBuffer());
|
auto z_user_mem = dnnl::memory(z_user_md, engine, zR->getBuffer());
|
||||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||||
|
|
|
@ -28,6 +28,55 @@ using namespace dnnl;
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace mkldnnUtils {
|
namespace mkldnnUtils {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){
|
||||||
|
|
||||||
|
std::vector<int64_t> vDims(rank);
|
||||||
|
for (auto i = 0; i < rank; i++) {
|
||||||
|
vDims[i] = array->sizeAt(i);
|
||||||
|
}
|
||||||
|
mklDims = dnnl::memory::dims(vDims);
|
||||||
|
}
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
dnnl::memory::format_tag getFormat(const int rank){
|
||||||
|
if (2 == rank) {
|
||||||
|
return dnnl::memory::format_tag::ab;
|
||||||
|
}
|
||||||
|
else if (3 == rank) {
|
||||||
|
return dnnl::memory::format_tag::abc;
|
||||||
|
}
|
||||||
|
else if (4 == rank) {
|
||||||
|
return dnnl::memory::format_tag::abcd;
|
||||||
|
}
|
||||||
|
else if (5 == rank) {
|
||||||
|
return dnnl::memory::format_tag::abcde;
|
||||||
|
}
|
||||||
|
else if (6 == rank) {
|
||||||
|
return dnnl::memory::format_tag::abcdef;
|
||||||
|
}
|
||||||
|
return dnnl::memory::format_tag::a; // 1 == dataSetRank
|
||||||
|
}
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd){
|
||||||
|
if (array->ews() != 1 || array->ordering() != 'c') {
|
||||||
|
mklMd.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
for (auto i = 0; i < rank; ++i) {
|
||||||
|
mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream,
|
||||||
|
std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG ){
|
||||||
|
|
||||||
|
auto user_mem = dnnl::memory(user_md, engine, array->getBuffer());
|
||||||
|
const bool bReorder = primitive_md != user_mem.get_desc();
|
||||||
|
auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem;
|
||||||
|
if (bReorder)
|
||||||
|
dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem);
|
||||||
|
args[DNNL_ARG] = mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
void poolingMKLDNN(const NDArray *input, NDArray *output,
|
void poolingMKLDNN(const NDArray *input, NDArray *output,
|
||||||
const int kD, const int kH, const int kW,
|
const int kD, const int kH, const int kW,
|
||||||
|
@ -113,12 +162,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// output
|
// output
|
||||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
|
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
|
||||||
|
@ -236,12 +280,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
|
||||||
std::unordered_map<int, dnnl::memory> args;
|
std::unordered_map<int, dnnl::memory> args;
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
|
||||||
const bool gradOReorder = op_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
|
||||||
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
|
||||||
if (gradOReorder)
|
|
||||||
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
|
||||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
|
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||||
|
@ -252,12 +291,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
|
||||||
if(mode == algorithm::pooling_max) {
|
if(mode == algorithm::pooling_max) {
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// z
|
// z
|
||||||
auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);
|
auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);
|
||||||
|
|
|
@ -90,6 +90,8 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_PLATFORM(softmax, ENGINE_CPU);
|
DECLARE_PLATFORM(softmax, ENGINE_CPU);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(softmax_bp, ENGINE_CPU);
|
||||||
|
|
||||||
DECLARE_PLATFORM(tanh, ENGINE_CPU);
|
DECLARE_PLATFORM(tanh, ENGINE_CPU);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -107,6 +109,41 @@ namespace sd {
|
||||||
|
|
||||||
dnnl::engine& getEngine(void* ptr);
|
dnnl::engine& getEngine(void* ptr);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This function creates memory dimentions
|
||||||
|
* @param const pointer to array
|
||||||
|
* @param const array rank
|
||||||
|
* @param reference to memory dimentions
|
||||||
|
*/
|
||||||
|
void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims);
|
||||||
|
/**
|
||||||
|
* This function generate memory format tag based on rank
|
||||||
|
* @param const array rank
|
||||||
|
* @return memory format
|
||||||
|
*/
|
||||||
|
dnnl::memory::format_tag getFormat(const int rank);
|
||||||
|
/**
|
||||||
|
* This function generate memory format tag based on rank
|
||||||
|
* @param const pointer to dataset
|
||||||
|
* @param const dataset rank
|
||||||
|
* @param reference to memory descriptor
|
||||||
|
* @return memory format
|
||||||
|
*/
|
||||||
|
void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd);
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
/**
|
||||||
|
* This function load and reorder user memory to mkl
|
||||||
|
* @param const pointer to dataset
|
||||||
|
* @param reference to mkl engine
|
||||||
|
* @param reference to mkl stream
|
||||||
|
* @param reference to args container for dnnl
|
||||||
|
* @param reference to user memory description
|
||||||
|
* @param primitive memory descriptor
|
||||||
|
* @param dnnl arg activation enumerator
|
||||||
|
*/
|
||||||
|
void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream,
|
||||||
|
std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Utility methods for MKLDNN
|
* Utility methods for MKLDNN
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -31,69 +31,36 @@ namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace platforms {
|
namespace platforms {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) {
|
static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) {
|
||||||
|
|
||||||
const auto xRank = x->rankOf();
|
const auto xRank = x->rankOf();
|
||||||
const auto zRank = z->rankOf();
|
dnnl::memory::dims xShape, zShape;
|
||||||
|
|
||||||
std::vector<int64_t> dimsX(xRank), dimsZ(zRank);
|
mkldnnUtils::getDims(x, xRank, xShape);
|
||||||
for (auto i = 0; i < xRank; i++) {
|
mkldnnUtils::getDims(z, xRank, zShape);
|
||||||
dimsX[i] = x->sizeAt(i);
|
|
||||||
dimsZ[i] = z->sizeAt(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
dnnl::memory::dims xShape = dnnl::memory::dims(dimsX);
|
|
||||||
dnnl::memory::dims zShape = dnnl::memory::dims(dimsZ);
|
|
||||||
|
|
||||||
dnnl::memory::format_tag format = dnnl::memory::format_tag::a; // 1 == xRank
|
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
|
||||||
if (2 == xRank && 1 == axis) {
|
// optimized cases
|
||||||
format = dnnl::memory::format_tag::ab;
|
if (2 == xRank && 0 == axis) {
|
||||||
}
|
|
||||||
else if (2 == xRank && 0 == axis) {
|
|
||||||
format = dnnl::memory::format_tag::ba;
|
format = dnnl::memory::format_tag::ba;
|
||||||
}
|
}
|
||||||
else if (3 == xRank) {
|
else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) {
|
||||||
format = dnnl::memory::format_tag::abc;
|
|
||||||
}
|
|
||||||
else if (4 == xRank && 3 == axis) {
|
|
||||||
format = dnnl::memory::format_tag::abcd;
|
|
||||||
}
|
|
||||||
else if (4 == xRank && 1 == axis && dimsX[2] * dimsX[3] > 1) {
|
|
||||||
format = dnnl::memory::format_tag::acdb;
|
format = dnnl::memory::format_tag::acdb;
|
||||||
}
|
}
|
||||||
else if (4 == xRank) {
|
|
||||||
format = dnnl::memory::format_tag::abcd;
|
|
||||||
}
|
|
||||||
else if (5 == xRank) {
|
|
||||||
format = dnnl::memory::format_tag::abcde;
|
|
||||||
}
|
|
||||||
else if (6 == xRank) {
|
|
||||||
format = dnnl::memory::format_tag::abcdef;
|
|
||||||
}
|
|
||||||
|
|
||||||
dnnl::memory::data_type xType = dnnl::memory::data_type::f32;
|
dnnl::memory::data_type xType = dnnl::memory::data_type::f32;
|
||||||
dnnl::memory::data_type zType = dnnl::memory::data_type::f32;
|
|
||||||
|
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
|
||||||
|
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
|
||||||
if (x->ews() != 1 || x->ordering() != 'c') {
|
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
for (auto i = 0; i < xRank; ++i) {
|
|
||||||
x_user_md.data.format_desc.blocking.strides[i] = x->strideAt(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// z
|
// z
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format);
|
||||||
if (z->ews() != 1 || z->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(z, xRank, z_user_md);
|
||||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
for (auto i = 0; i < xRank; ++i) {
|
|
||||||
z_user_md.data.format_desc.blocking.strides[i] = z->strideAt(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -101,7 +68,6 @@ namespace sd {
|
||||||
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
|
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
|
||||||
|
|
||||||
// operation primitive description
|
// operation primitive description
|
||||||
// todo check this
|
|
||||||
dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis);
|
dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis);
|
||||||
|
|
||||||
dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine);
|
dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine);
|
||||||
|
@ -114,12 +80,7 @@ namespace sd {
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
|
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// z
|
// z
|
||||||
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
|
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
|
||||||
|
@ -178,6 +139,136 @@ namespace sd {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) {
|
||||||
|
|
||||||
|
const auto xRank = x->rankOf();
|
||||||
|
const auto dLdzRank = dLdz->rankOf();
|
||||||
|
|
||||||
|
dnnl::memory::dims xShape, dLdxShape, dLdzShape;
|
||||||
|
|
||||||
|
mkldnnUtils::getDims(x, xRank, xShape);
|
||||||
|
mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
|
||||||
|
mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape);
|
||||||
|
|
||||||
|
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
|
||||||
|
|
||||||
|
// x
|
||||||
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||||
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||||
|
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
|
||||||
|
|
||||||
|
// dLdx
|
||||||
|
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
|
||||||
|
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
|
||||||
|
mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md);
|
||||||
|
// todo if mkl does not support broadcast we can remove this
|
||||||
|
format = mkldnnUtils::getFormat(dLdzRank);
|
||||||
|
|
||||||
|
// dLdz
|
||||||
|
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
|
||||||
|
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
|
||||||
|
mkldnnUtils::setBlockStrides(dLdz, dLdzRank, dLdz_user_md);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// operation primitive description
|
||||||
|
// forward description
|
||||||
|
dnnl::softmax_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis);
|
||||||
|
dnnl::softmax_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||||
|
|
||||||
|
// backward description
|
||||||
|
dnnl::softmax_backward::desc op_bp_desc(dLdz_mkl_md, dLdx_mkl_md, axis);
|
||||||
|
dnnl::softmax_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, dnnl::memory> argsbp, argsff;
|
||||||
|
|
||||||
|
dnnl::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required for forward
|
||||||
|
// input
|
||||||
|
mkldnnUtils::loadDataToMklStream(x, engine, stream, argsff, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
|
|
||||||
|
// dLdx
|
||||||
|
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());
|
||||||
|
const bool dLdxReorder = op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc();
|
||||||
|
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : dLdx_user_mem;
|
||||||
|
argsff[DNNL_ARG_DST] = dLdx_mkl_mem;
|
||||||
|
|
||||||
|
// check and arg set for backprob
|
||||||
|
argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
|
||||||
|
argsbp[DNNL_ARG_DST] = dLdx_mkl_mem;
|
||||||
|
// dLdz
|
||||||
|
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, argsbp, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
|
||||||
|
|
||||||
|
// run calculations forward
|
||||||
|
dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff);
|
||||||
|
|
||||||
|
// run calculations backward
|
||||||
|
dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp);
|
||||||
|
|
||||||
|
// reorder outputs if necessary
|
||||||
|
if (dLdxReorder)
|
||||||
|
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PLATFORM_IMPL(softmax_bp, ENGINE_CPU) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto dLdz = INPUT_VARIABLE(1);
|
||||||
|
auto dLdx = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const int rank = input->rankOf();
|
||||||
|
const int dLdzRank = dLdz->rankOf();
|
||||||
|
int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;
|
||||||
|
|
||||||
|
if (dim < 0) {
|
||||||
|
dim += rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "SOFTMAX_MKLDNN_BP OP: the rank of input and dLdz must be less or qual 6, but got input rank = %i and dLdz rank rank = %i instead !", rank, dLdzRank);
|
||||||
|
|
||||||
|
// mkldnnSoftMax
|
||||||
|
softmaxBpMKLDNN(input, dLdz, dLdx, dim);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(softmax_bp, ENGINE_CPU) {
|
||||||
|
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto dLdz = INPUT_VARIABLE(1);
|
||||||
|
auto dLdx = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const DataType xType = x->dataType();
|
||||||
|
const DataType dLdzType = dLdz->dataType();
|
||||||
|
const DataType dLdxType = dLdx->dataType();
|
||||||
|
|
||||||
|
const int xRank = x->rankOf();
|
||||||
|
const int dLdzRank = dLdz->rankOf();
|
||||||
|
|
||||||
|
bool bSupportedRanks = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty());
|
||||||
|
|
||||||
|
if (bSupportedRanks) {
|
||||||
|
for (int i = 0; i < xRank; i++) {
|
||||||
|
if (x->sizeAt(i) != dLdz->sizeAt(i)) {
|
||||||
|
bSupportedRanks = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//Source Destination
|
||||||
|
//f32 f32
|
||||||
|
return block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,52 +35,21 @@ namespace sd {
|
||||||
static void tanhMKLDNN(const NDArray* x, NDArray* z) {
|
static void tanhMKLDNN(const NDArray* x, NDArray* z) {
|
||||||
|
|
||||||
const auto xRank = x->rankOf();
|
const auto xRank = x->rankOf();
|
||||||
|
dnnl::memory::dims xShape, zShape;
|
||||||
|
|
||||||
std::vector<int64_t> dimsX(xRank), dimsZ(xRank);
|
mkldnnUtils::getDims(x, xRank, xShape);
|
||||||
for (auto i = 0; i < xRank; i++) {
|
mkldnnUtils::getDims(z, xRank, zShape);
|
||||||
dimsX[i] = x->sizeAt(i);
|
|
||||||
dimsZ[i] = z->sizeAt(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
dnnl::memory::dims xShape = dnnl::memory::dims(dimsX);
|
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
|
||||||
dnnl::memory::dims zShape = dnnl::memory::dims(dimsZ);
|
|
||||||
|
|
||||||
dnnl::memory::format_tag format = dnnl::memory::format_tag::a;
|
|
||||||
if (2 == xRank) {
|
|
||||||
format = dnnl::memory::format_tag::ab;
|
|
||||||
}
|
|
||||||
else if (3 == xRank) {
|
|
||||||
format = dnnl::memory::format_tag::abc;
|
|
||||||
}
|
|
||||||
else if (4 == xRank) {
|
|
||||||
format = dnnl::memory::format_tag::abcd;
|
|
||||||
}
|
|
||||||
else if (5 == xRank) {
|
|
||||||
format = dnnl::memory::format_tag::abcde;
|
|
||||||
}
|
|
||||||
else if (6 == xRank) {
|
|
||||||
format = dnnl::memory::format_tag::abcdef;
|
|
||||||
}
|
|
||||||
|
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||||
|
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
|
||||||
if (x->ews() != 1 || x->ordering() != 'c') {
|
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
for (auto i = 0; i < xRank; ++i) {
|
|
||||||
x_user_md.data.format_desc.blocking.strides[i] = x->strideAt(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// z
|
// z
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
|
||||||
if (z->ews() != 1 || z->ordering() != 'c') {
|
mkldnnUtils::setBlockStrides(z, xRank, z_user_md);
|
||||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
|
||||||
for (auto i = 0; i < xRank; ++i) {
|
|
||||||
z_user_md.data.format_desc.blocking.strides[i] = z->strideAt(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@ -99,12 +68,7 @@ namespace sd {
|
||||||
|
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
// input
|
// input
|
||||||
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
|
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
|
||||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
|
||||||
if (xReorder)
|
|
||||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
|
||||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
|
||||||
|
|
||||||
// z
|
// z
|
||||||
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
|
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
|
||||||
|
|
|
@ -76,3 +76,63 @@ TEST_F(DeclarableOpsTests18, test_tanh_2) {
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST) {
|
||||||
|
|
||||||
|
NDArray input('c', { 2, 2 }, { 1,2,3,4 }, DataType::FLOAT32);
|
||||||
|
NDArray epsilon('c', { 2, 2 }, { .1, .2, .3, .4 }, DataType::FLOAT32);
|
||||||
|
|
||||||
|
int axis = 1;
|
||||||
|
|
||||||
|
NDArray output('c', { 2, 2 }, DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray exp('c', { 2, 2 }, { -0.019661, 0.019661, -0.019661, 0.019661 }, DataType::FLOAT32);
|
||||||
|
|
||||||
|
sd::ops::softmax_bp op;
|
||||||
|
|
||||||
|
Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis });
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
ASSERT_TRUE(output.equalsTo(exp));
|
||||||
|
|
||||||
|
}
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST2) {

    // Backprop through softmax over the last axis of a rank-4, c-ordered array.
    // The incoming gradient (epsilon) repeats the same triple along the softmax
    // axis, so the expected output repeats a matching triple.
    const int axis = -1;

    NDArray input('c', { 4, 5, 2, 3 }, DataType::FLOAT32);
    input.linspace(0.1, 0.2);

    NDArray epsilon('c', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32);

    // Expected gradient w.r.t. the softmax input.
    NDArray exp('c', { 4, 5, 2, 3 }, { -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253 }, DataType::FLOAT32);

    NDArray output('c', { 4, 5, 2, 3 }, DataType::FLOAT32);

    sd::ops::softmax_bp op;
    const Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis });

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(output.equalsTo(exp));
}
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) {

    // Backprop through softmax over axis 1 with f-ordered (column-major) arrays.
    // The reference values are authored in c-order (expC) and copied into an
    // f-ordered array so equalsTo compares logically identical data.
    const int axis = 1;

    NDArray input('f', { 4, 5, 2, 3 }, DataType::FLOAT32);
    input.linspace(-5., 0.5);

    NDArray epsilon('f', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32);

    // Expected result, written in c-order.
    NDArray expC('c', { 4, 5, 2, 3 }, { -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909, -0.000000, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, -0.000149, 0.000054, 0.000095, 0.000095, -0.000149, 0.000054, -0.001760, 0.002943, -0.001183, -0.001183, -0.001760, 0.002943, 0.001909, -0.002997, 0.001088, 0.001088, 0.001909, -0.002997, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000054, 0.000095, -0.000149, -0.000149, 0.000054, 0.000095, 0.002943, -0.001183, -0.001760, -0.001760, 0.002943, -0.001183, -0.002997, 0.001088, 0.001909, 0.001909, -0.002997, 0.001088, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909 }, DataType::FLOAT32);

    // Reorder the reference into the same (f) ordering as the op's output.
    NDArray exp('f', { 4, 5, 2, 3 }, DataType::FLOAT32);
    exp.assign(expC);

    NDArray output('f', { 4, 5, 2, 3 }, DataType::FLOAT32);

    sd::ops::softmax_bp op;
    const Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis });

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(output.equalsTo(exp));
}
|
||||||
|
|
|
@ -71,8 +71,10 @@ TEST_F(MklDnnTests, helpers_includer) {
|
||||||
|
|
||||||
sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax;
|
sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_softmax_bp_ENGINE_CPU softmax_bp;
|
||||||
|
|
||||||
sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh;
|
sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh;
|
||||||
|
|
||||||
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &tanh });
|
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh });
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
Loading…
Reference in New Issue