diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 42b7d231c..cc52e90b3 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -90,32 +90,13 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray // x dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); - if(x->ews() != 1 || x->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1); - if(xRank > 2) { - x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3); - } - if(xRank > 4) - x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4); - } + mkldnnUtils::setBlockStrides(x, xRank, x_user_md); // z, output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); - if(z->ews() != 1 || z->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = z->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = z->strideAt(1); - if(xRank > 2) { - z_user_md.data.format_desc.blocking.strides[2] = z->strideAt(2); - z_user_md.data.format_desc.blocking.strides[3] = z->strideAt(3); - } - if(xRank > 4) - z_user_md.data.format_desc.blocking.strides[4] = z->strideAt(4); - } + + mkldnnUtils::setBlockStrides(z, xRank, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -131,14 +112,9 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray // 
provide memory and check whether reorder is required // x - auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); - const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; - - // z + mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); + + // z auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; @@ -230,47 +206,20 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // x dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); - if(x->ews() != 1 || x->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1); - if(xRank > 2) { - x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3); - } - if(xRank > 4) - x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4); - } + mkldnnUtils::setBlockStrides(x, xRank, x_user_md); + // dLdO dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); - if(dLdO->ews() != 1 || dLdO->ordering() != 'c') { - dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format - dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->strideAt(0); - 
dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->strideAt(1); - if(xRank > 2) { - dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->strideAt(2); - dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->strideAt(3); - } - if(xRank > 4) - dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->strideAt(4); - } + + mkldnnUtils::setBlockStrides(dLdO, xRank, dLdO_user_md); // dLdI dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); - if(dLdI->ews() != 1 || dLdI->ordering() != 'c') { - dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format - dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->strideAt(0); - dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->strideAt(1); - if(xRank > 2) { - dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->strideAt(2); - dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->strideAt(3); - } - if(xRank > 4) - dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->strideAt(4); - } + + mkldnnUtils::setBlockStrides(dLdI, xRank, dLdI_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -290,20 +239,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // provide memory and check whether reorder is required // x - auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); - const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? 
dnnl::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_bp_prim_desc.src_desc(), DNNL_ARG_SRC); // dLdO - auto dLdO_user_mem = dnnl::memory(dLdO_user_md, engine, dLdO->getBuffer()); - const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc(); - auto dLdO_mkl_mem = dLdOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem; - if (dLdOReorder) - dnnl::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem); - args[DNNL_ARG_DIFF_DST] = dLdO_mkl_mem; + mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, args, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); // mean auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index fd34368a6..9d236d293 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -67,13 +67,7 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - } + mkldnnUtils::setBlockStrides(input, 4, x_user_md); // weights 
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); @@ -92,13 +86,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - if(output->ews() != 1 || output->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); - z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); - z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); - } + + mkldnnUtils::setBlockStrides(output, 4, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -114,20 +103,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // bias if(bias != nullptr) { @@ -185,13 +164,7 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - } + mkldnnUtils::setBlockStrides(input, 4, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); @@ -205,25 +178,13 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - if(gradO->ews() != 1 || gradO->ordering() != 'c') { - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); - } - + mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); + // gradI dnnl::memory::desc 
gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(gradI->ews() != 1 || gradI->ordering() != 'c') { - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); - } - + mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); + // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); @@ -260,20 +221,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 3003713e3..6c0575378 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -70,14 +70,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4); - } + mkldnnUtils::setBlockStrides(input, 5, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); @@ -97,14 +90,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - if(output->ews() != 1 || output->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // 
overrides format - z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); - z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); - z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); - z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4); - } + mkldnnUtils::setBlockStrides(output, 5, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -120,21 +106,11 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; - + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + // bias if(bias != nullptr) { auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); @@ -194,14 +170,7 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4); - } + mkldnnUtils::setBlockStrides(input, 5, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); @@ -216,26 +185,14 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - if(gradO->ews() != 1 || gradO->ordering() != 'c') { - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); - 
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); - gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4); - } + + mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(gradI->ews() != 1 || gradI->ordering() != 'c') { - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); - gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4); - } + + mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); @@ -274,20 +231,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index 9a2051232..1ee177e6a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -88,13 +88,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - } + mkldnnUtils::setBlockStrides(input, 4, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -113,13 +107,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); - if(output->ews() != 1 || output->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - 
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); - z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); - z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); - } + mkldnnUtils::setBlockStrides(output, 4, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -136,20 +124,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // bias if(bias != nullptr) { @@ -216,13 +194,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - } + mkldnnUtils::setBlockStrides(input, 4, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -236,24 +208,12 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); - if(gradO->ews() != 1 || gradO->ordering() != 'c') { - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); - } + mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); // gradI 
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - if(gradI->ews() != 1 || gradI->ordering() != 'c') { - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); - } + mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); @@ -291,20 +251,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index 2e210d0f4..e7283e1d3 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -76,24 +76,12 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); - if(gradO->ews() != 1 || gradO->ordering() != 'c') { - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); - } + mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - if(gradI->ews() != 1 || gradI->ordering() != 'c') { - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); - 
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); - } + mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -113,20 +101,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // provide memory buffers and check whether reorder is required // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); - const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_mem = gradOReorder ? 
dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorder) - dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; + mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); // gradI auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); @@ -146,8 +124,6 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // shape::printArray(z_mkl_mem.map_data(),8); } - - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index 50e766d3b..dc50288a0 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -89,14 +89,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4); - } + mkldnnUtils::setBlockStrides(input, 5, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -116,14 +109,7 @@ static void deconv3dMKLDNN(const 
NDArray* input, const NDArray* weights, const N // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); - if(output->ews() !=1 || output->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); - z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); - z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); - z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4); - } + mkldnnUtils::setBlockStrides(output, 5, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -140,20 +126,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // bias if(bias != nullptr) { @@ -223,14 +199,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4); - } + mkldnnUtils::setBlockStrides(input, 5, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -245,26 +214,12 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); - if(gradO->ews() != 1 || gradO->ordering() != 'c') { - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); - 
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4); - } + mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - if(gradI->ews() != 1 || gradI->ordering() != 'c') { - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); - gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4); - } + mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat); @@ -304,20 +259,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index db0a1979c..ae4409923 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -98,13 +98,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); // do permutation NHWC -> NCHW - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - } + mkldnnUtils::setBlockStrides(input, 4, x_user_md); // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -124,13 +118,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat); - if(output->ews() != 1 || 
output->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); // do permutation NHWC -> NCHW - z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); - z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); - } + mkldnnUtils::setBlockStrides(output, 4, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -147,20 +135,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // bias if(bias != nullptr) { @@ -235,13 +213,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); - } + mkldnnUtils::setBlockStrides(input, 4, x_user_md); // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -256,24 +228,12 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat); - if(gradO->ews() != 1 || gradO->ordering() != 'c') { - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); - } + 
mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat); - if(gradI->ews() != 1 || gradI->ordering() != 'c') { - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); - } + mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); // gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); @@ -312,20 +272,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); // weights - auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); - const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? 
dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; - if (wReorder) - dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index ad612435d..c4d987054 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -272,29 +272,14 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* // provide memory and check whether reorder is required // x - auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); - const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc(); - auto x_lstm_mem = xReorder ? dnnl::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem; - if (xReorder) - reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem); - args[DNNL_ARG_SRC_LAYER] = x_lstm_mem; - + mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, lstm_prim_desc.src_layer_desc(), DNNL_ARG_SRC_LAYER); + // wx - auto wx_user_mem = dnnl::memory(wx_user_md, engine, Wx->getBuffer()); - const bool wxReorder = lstm_prim_desc.weights_layer_desc()!= wx_user_mem.get_desc(); - auto wx_lstm_mem = wxReorder ? 
dnnl::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem; - if (wxReorder) - reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem); - args[DNNL_ARG_WEIGHTS_LAYER] = wx_lstm_mem; + mkldnnUtils::loadDataToMklStream(Wx, engine, stream, args, wx_user_md, lstm_prim_desc.weights_layer_desc(), DNNL_ARG_WEIGHTS_LAYER); // wr - auto wr_user_mem = dnnl::memory(wr_user_md, engine, Wr->getBuffer()); - const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc(); - auto wr_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem; - if (wrReorder) - reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem); - args[DNNL_ARG_WEIGHTS_ITER] = wr_lstm_mem; - + mkldnnUtils::loadDataToMklStream(Wr, engine, stream, args, wr_user_md, lstm_prim_desc.weights_iter_desc(), DNNL_ARG_WEIGHTS_ITER); + // h auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer()); const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); @@ -303,32 +288,17 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* // b if(b) { - auto b_user_mem = dnnl::memory(b_user_md, engine, b->getBuffer()); - const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc(); - auto b_lstm_mem = bReorder ? dnnl::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem; - if (bReorder) - reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem); - args[DNNL_ARG_BIAS] = b_lstm_mem; + mkldnnUtils::loadDataToMklStream(b, engine, stream, args, b_user_md, lstm_prim_desc.bias_desc(), DNNL_ARG_BIAS); } // hI if(hI) { - auto hI_user_mem = dnnl::memory(hI_user_md, engine, hI->getBuffer()); - const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc(); - auto hI_lstm_mem = hIReorder ? 
dnnl::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem; - if (hIReorder) - reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem); - args[DNNL_ARG_SRC_ITER] = hI_lstm_mem; + mkldnnUtils::loadDataToMklStream(hI, engine, stream, args, hI_user_md, lstm_prim_desc.src_iter_desc(), DNNL_ARG_SRC_ITER); } // cI if(cI) { - auto cI_user_mem = dnnl::memory(cI_user_md, engine, cI->getBuffer()); - const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc(); - auto cI_lstm_mem = cIReorder ? dnnl::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem; - if (cIReorder) - reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem); - args[DNNL_ARG_SRC_ITER_C] = cI_lstm_mem; + mkldnnUtils::loadDataToMklStream(cI, engine, stream, args, cI_user_md, lstm_prim_desc.src_iter_c_desc(), DNNL_ARG_SRC_ITER_C); } bool hLReorder(false), cLReorder(false); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index 7345b6543..805507277 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -163,21 +163,25 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b // provide memory buffers and check whether reorder is required // input + mkldnnUtils::loadDataToMklStream(xTR, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + /* auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer()); const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); auto x_mkl_mem = xReorder ? 
dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; if (xReorder) dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); args[DNNL_ARG_SRC] = x_mkl_mem; - +*/ // y + mkldnnUtils::loadDataToMklStream(yTR, engine, stream, args, y_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + /* auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer()); const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem; if (yReorder) dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem); args[DNNL_ARG_WEIGHTS] = y_mkl_mem; - +*/ // z auto z_user_mem = dnnl::memory(z_user_md, engine, zR->getBuffer()); const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 6cb74d628..1c6974ea8 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -28,6 +28,55 @@ using namespace dnnl; namespace sd { namespace mkldnnUtils { +////////////////////////////////////////////////////////////////////// +void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){ + + std::vector vDims(rank); + for (auto i = 0; i < rank; i++) { + vDims[i] = array->sizeAt(i); + } + mklDims = dnnl::memory::dims(vDims); +} +////////////////////////////////////////////////////////////////////// +dnnl::memory::format_tag getFormat(const int rank){ + if (2 == rank) { + return dnnl::memory::format_tag::ab; + } + else if (3 == rank) { + return dnnl::memory::format_tag::abc; + } + else if (4 == rank) { + return dnnl::memory::format_tag::abcd; + } + else if (5 == rank) { + return dnnl::memory::format_tag::abcde; + } + else if (6 == rank) { + return dnnl::memory::format_tag::abcdef; + } + return 
dnnl::memory::format_tag::a; // 1 == dataSetRank +} +////////////////////////////////////////////////////////////////////// +void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd){ + if (array->ews() != 1 || array->ordering() != 'c') { + mklMd.data.format_kind = dnnl_blocked; // overrides format + for (auto i = 0; i < rank; ++i) { + mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i); + } + } +} +//////////////////////////////////////////////////////////////////////////////////////////////// +void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream, + std::unordered_map& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG ){ + + auto user_mem = dnnl::memory(user_md, engine, array->getBuffer()); + const bool bReorder = primitive_md != user_mem.get_desc(); + auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; + if (bReorder) + dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); + args[DNNL_ARG] = mkl_mem; +} + ////////////////////////////////////////////////////////////////////// void poolingMKLDNN(const NDArray *input, NDArray *output, const int kD, const int kH, const int kW, @@ -113,12 +162,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output, // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? 
dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); // output auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); @@ -236,13 +280,8 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, std::unordered_map args; // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); - const bool gradOReorder = op_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorder) - dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; - + mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); + // gradI auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); @@ -252,13 +291,8 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, if(mode == algorithm::pooling_max) { // input - auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); - const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? 
dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; - + mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); + // z auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); args[DNNL_ARG_DST] = z_mkl_mem; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 693e515b1..1237baac0 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -90,6 +90,8 @@ namespace sd { DECLARE_PLATFORM(softmax, ENGINE_CPU); + DECLARE_PLATFORM(softmax_bp, ENGINE_CPU); + DECLARE_PLATFORM(tanh, ENGINE_CPU); } @@ -107,6 +109,41 @@ namespace sd { dnnl::engine& getEngine(void* ptr); + /** + * This function creates memory dimensions + * @param const pointer to array + * @param const array rank + * @param reference to memory dimensions + */ + void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims); + /** + * This function generates memory format tag based on rank + * @param const array rank + * @return memory format + */ + dnnl::memory::format_tag getFormat(const int rank); + /** + * This function switches the memory descriptor to blocked format and copies the array strides into it when the array is not 'c' ordered or its ews is not 1 + * @param const pointer to dataset + * @param const dataset rank + * @param reference to memory descriptor + * @note the descriptor is updated in place, nothing is returned + */ + void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd); + ////////////////////////////////////////////////////////////////////// + /** + * This function loads and reorders user memory to mkl + * @param const pointer to dataset + * @param reference to mkl engine + * @param reference to mkl stream + * @param reference to args container for dnnl + * @param reference to user memory description
+ * @param primitive memory descriptor + * @param dnnl arg activation enumerator + */ + void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream, + std::unordered_map& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG); + /** * Utility methods for MKLDNN */ diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp index 924693f85..d67d205da 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp @@ -31,69 +31,36 @@ namespace sd { namespace ops { namespace platforms { + ////////////////////////////////////////////////////////////////////// static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) { const auto xRank = x->rankOf(); - const auto zRank = z->rankOf(); + dnnl::memory::dims xShape, zShape; - std::vector dimsX(xRank), dimsZ(zRank); - for (auto i = 0; i < xRank; i++) { - dimsX[i] = x->sizeAt(i); - dimsZ[i] = z->sizeAt(i); - } + mkldnnUtils::getDims(x, xRank, xShape); + mkldnnUtils::getDims(z, xRank, zShape); - dnnl::memory::dims xShape = dnnl::memory::dims(dimsX); - dnnl::memory::dims zShape = dnnl::memory::dims(dimsZ); - dnnl::memory::format_tag format = dnnl::memory::format_tag::a; // 1 == xRank - if (2 == xRank && 1 == axis) { - format = dnnl::memory::format_tag::ab; - } - else if (2 == xRank && 0 == axis) { + dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); + // optimized cases + if (2 == xRank && 0 == axis) { format = dnnl::memory::format_tag::ba; } - else if (3 == xRank) { - format = dnnl::memory::format_tag::abc; - } - else if (4 == xRank && 3 == axis) { - format = dnnl::memory::format_tag::abcd; - } - else if (4 == xRank && 1 == axis && dimsX[2] * dimsX[3] > 1) { + else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) { format = dnnl::memory::format_tag::acdb; } - else if (4 == xRank) { - 
format = dnnl::memory::format_tag::abcd; - } - else if (5 == xRank) { - format = dnnl::memory::format_tag::abcde; - } - else if (6 == xRank) { - format = dnnl::memory::format_tag::abcdef; - } dnnl::memory::data_type xType = dnnl::memory::data_type::f32; - dnnl::memory::data_type zType = dnnl::memory::data_type::f32; dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); - - if (x->ews() != 1 || x->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - for (auto i = 0; i < xRank; ++i) { - x_user_md.data.format_desc.blocking.strides[i] = x->strideAt(i); - } - } + mkldnnUtils::setBlockStrides(x, xRank, x_user_md); // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); - if (z->ews() != 1 || z->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - for (auto i = 0; i < xRank; ++i) { - z_user_md.data.format_desc.blocking.strides[i] = z->strideAt(i); - } - } + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format); + mkldnnUtils::setBlockStrides(z, xRank, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -101,7 +68,6 @@ namespace sd { dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) // operation primitive description - // todo check this dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis); dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine); @@ -114,12 +80,7 @@ namespace sd { // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); - const bool xReorder = 
op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); // z auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); @@ -178,6 +139,136 @@ namespace sd { } + ////////////////////////////////////////////////////////////////////// + static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) { + + const auto xRank = x->rankOf(); + const auto dLdzRank = dLdz->rankOf(); + + dnnl::memory::dims xShape, dLdxShape, dLdzShape; + + mkldnnUtils::getDims(x, xRank, xShape); + mkldnnUtils::getDims(dLdx, xRank, dLdxShape); + mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape); + + dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); + + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(x, xRank, x_user_md); + + // dLdx + dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md); + // todo if mkl does not support broadcast we can remove this + format = mkldnnUtils::getFormat(dLdzRank); + + // dLdz + dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(dLdz, dLdzRank, dLdz_user_md); + + auto engine = 
mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + // forward description + dnnl::softmax_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis); + dnnl::softmax_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward description + dnnl::softmax_backward::desc op_bp_desc(dLdz_mkl_md, dLdx_mkl_md, axis); + dnnl::softmax_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map argsbp, argsff; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required for forward + // input + mkldnnUtils::loadDataToMklStream(x, engine, stream, argsff, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); + + // dLdx + auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); + const bool dLdxReorder = op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc(); + auto dLdx_mkl_mem = dLdxReorder ? 
dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : dLdx_user_mem; + argsff[DNNL_ARG_DST] = dLdx_mkl_mem; + + // check and arg set for backprop + argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; + argsbp[DNNL_ARG_DST] = dLdx_mkl_mem; + // dLdz + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, argsbp, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); + + // run calculations forward + dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); + + // run calculations backward + dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp); + + // reorder outputs if necessary + if (dLdxReorder) + dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); + + stream.wait(); + } + + + PLATFORM_IMPL(softmax_bp, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); + + const int rank = input->rankOf(); + const int dLdzRank = dLdz->rankOf(); + int dim = block.getIArguments()->size() > 0 ? 
INT_ARG(0) : rank - 1; + + if (dim < 0) { + dim += rank; + } + + REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + + REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "SOFTMAX_MKLDNN_BP OP: the rank of input and dLdz must be less or qual 6, but got input rank = %i and dLdz rank rank = %i instead !", rank, dLdzRank); + + // mkldnnSoftMax + softmaxBpMKLDNN(input, dLdz, dLdx, dim); + + return Status::OK(); + } + + PLATFORM_CHECK(softmax_bp, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); + + const int xRank = x->rankOf(); + const int dLdzRank = dLdz->rankOf(); + + bool bSupportedRanks = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); + + if (bSupportedRanks) { + for (int i = 0; i < xRank; i++) { + if (x->sizeAt(i) != dLdz->sizeAt(i)) { + bSupportedRanks = false; + break; + } + } + } + + //Source Destination + //f32 f32 + return block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32); + } + } } } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp index 9a8bc9f4a..5b08973d9 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp @@ -35,52 +35,21 @@ namespace sd { static void tanhMKLDNN(const NDArray* x, NDArray* z) { const auto xRank = x->rankOf(); + dnnl::memory::dims xShape, zShape; - std::vector dimsX(xRank), dimsZ(xRank); - for (auto i = 0; i < xRank; i++) { - dimsX[i] = x->sizeAt(i); - dimsZ[i] = z->sizeAt(i); - } + mkldnnUtils::getDims(x, xRank, xShape); + 
mkldnnUtils::getDims(z, xRank, zShape); - dnnl::memory::dims xShape = dnnl::memory::dims(dimsX); - dnnl::memory::dims zShape = dnnl::memory::dims(dimsZ); - - dnnl::memory::format_tag format = dnnl::memory::format_tag::a; - if (2 == xRank) { - format = dnnl::memory::format_tag::ab; - } - else if (3 == xRank) { - format = dnnl::memory::format_tag::abc; - } - else if (4 == xRank) { - format = dnnl::memory::format_tag::abcd; - } - else if (5 == xRank) { - format = dnnl::memory::format_tag::abcde; - } - else if (6 == xRank) { - format = dnnl::memory::format_tag::abcdef; - } + dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - - if (x->ews() != 1 || x->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - for (auto i = 0; i < xRank; ++i) { - x_user_md.data.format_desc.blocking.strides[i] = x->strideAt(i); - } - } + mkldnnUtils::setBlockStrides(x, xRank, x_user_md); // z dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); - if (z->ews() != 1 || z->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - for (auto i = 0; i < xRank; ++i) { - z_user_md.data.format_desc.blocking.strides[i] = z->strideAt(i); - } - } + mkldnnUtils::setBlockStrides(z, xRank, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -99,12 +68,7 @@ namespace sd { // provide memory buffers and check whether reorder is required // input - auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); - const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? 
dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; + mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); // z auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp index ee5dc9e35..48ea77709 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -76,3 +76,63 @@ TEST_F(DeclarableOpsTests18, test_tanh_2) { ASSERT_EQ(e, z); } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST) { + + NDArray input('c', { 2, 2 }, { 1,2,3,4 }, DataType::FLOAT32); + NDArray epsilon('c', { 2, 2 }, { .1, .2, .3, .4 }, DataType::FLOAT32); + + int axis = 1; + + NDArray output('c', { 2, 2 }, DataType::FLOAT32); + + NDArray exp('c', { 2, 2 }, { -0.019661, 0.019661, -0.019661, 0.019661 }, DataType::FLOAT32); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST2) { + + NDArray input('c', { 4, 5, 2, 3 }, DataType::FLOAT32); + NDArray epsilon('c', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 
0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32); + input.linspace(0.1, 0.2); + + int axis = -1; + + NDArray output('c', { 4, 5, 2, 3 }, DataType::FLOAT32); + NDArray exp('c', { 4, 5, 2, 3 }, { -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, 
-0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253 }, DataType::FLOAT32); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) { + + NDArray input('f', { 4, 5, 2, 3 }, DataType::FLOAT32); + NDArray epsilon('f', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 
0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32); + input.linspace(-5., 0.5); + + int axis = 1; + + NDArray output('f', { 4, 5, 2, 3 }, DataType::FLOAT32); + NDArray expC('c', { 4, 5, 2, 3 }, { -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909, -0.000000, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, -0.000149, 0.000054, 0.000095, 0.000095, -0.000149, 0.000054, -0.001760, 0.002943, -0.001183, -0.001183, -0.001760, 0.002943, 0.001909, -0.002997, 0.001088, 0.001088, 0.001909, -0.002997, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000054, 0.000095, -0.000149, -0.000149, 0.000054, 0.000095, 0.002943, -0.001183, -0.001760, -0.001760, 0.002943, -0.001183, -0.002997, 0.001088, 0.001909, 0.001909, -0.002997, 0.001088, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909 }, DataType::FLOAT32); + + NDArray exp('f', { 4, 5, 2, 3 }, DataType::FLOAT32); + exp.assign(expC); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + +} diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index c482163d8..2f88e069e 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -71,8 +71,10 @@ TEST_F(MklDnnTests, 
helpers_includer) { sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax; + sd::ops::platforms::PLATFORM_softmax_bp_ENGINE_CPU softmax_bp; + sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; - printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &tanh }); + printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh }); #endif } \ No newline at end of file