diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 1237baac0..29b5ebf2a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -94,6 +94,8 @@ namespace sd { DECLARE_PLATFORM(tanh, ENGINE_CPU); + DECLARE_PLATFORM(tanh_bp, ENGINE_CPU); + } } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp index 5b08973d9..5a3ab0f57 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp @@ -109,12 +109,126 @@ namespace sd { const DataType zType = z->dataType(); const int xRank = x->rankOf(); - bool bSupportedRanks = xRank < 7; + bool bSupportedRanks = !x->isEmpty() && xRank < 7 && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); /* Source Destination f32 f32 */ - return !x->isEmpty() && block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); + return block.isUseMKLDNN() && bSupportedRanks; + } + + + ////////////////////////////////////////////////////////////////////// + static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) { + + const auto xRank = x->rankOf(); + dnnl::memory::dims xShape, dLdzShape, dLdxShape; + + mkldnnUtils::getDims(x, xRank, xShape); + mkldnnUtils::getDims(dLdz, xRank, dLdzShape); + mkldnnUtils::getDims(dLdx, xRank, dLdxShape); + + dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); + + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(x, xRank, x_user_md); + + // dLdz + dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(dLdz, xRank, dLdz_user_md); + + // dLdx + dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // operation primitive description + // forward + dnnl::eltwise_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, algorithm::eltwise_tanh, x_mkl_md, 0, 0); + dnnl::eltwise_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward description + dnnl::eltwise_backward::desc op_desc(algorithm::eltwise_tanh, dLdz_mkl_md, x_mkl_md, 0, 0); + dnnl::eltwise_backward::primitive_desc op_prim_desc(op_desc, engine, op_ff_prim_desc); + + // provide memory buffers and check whether reorder is required for forward + // input + mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + + // dLdz + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, args, dLdz_user_md, op_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); + + // dLdx + auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); + const bool dLdxReorder = op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc(); + auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_prim_desc.diff_src_desc(), engine) : dLdx_user_mem; + args[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; + + // run calculations backward + dnnl::eltwise_backward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (dLdxReorder) + dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); + + stream.wait(); + } + + + PLATFORM_IMPL(tanh_bp, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); + + const int rank = input->rankOf(); + const int dLdzRank = dLdz->rankOf(); + + REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "TANH_BP_MKLDNN OP: the rank of input and dLdz must be less or qual 6, but got input rank = %i and dLdz rank rank = %i instead !", rank, dLdzRank); + + // mkldnnSoftMax + tanhBpMKLDNN(input, dLdz, dLdx); + + return Status::OK(); + } + + PLATFORM_CHECK(tanh_bp, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); + + const int xRank = x->rankOf(); + const int dLdzRank = dLdz->rankOf(); + + bool bSupportedRanks = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); + bSupportedRanks &= (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32); + + if (bSupportedRanks) { + for (int i = 0; i < xRank; i++) { + if (x->sizeAt(i) != dLdz->sizeAt(i)) { + bSupportedRanks = false; + break; + } + } + } + + //Source Destination + //f32 f32 + return block.isUseMKLDNN() && bSupportedRanks; } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp index 48ea77709..f8de783c9 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -75,7 +75,59 @@ TEST_F(DeclarableOpsTests18, test_tanh_2) { op.execute({ &x }, { &z }); ASSERT_EQ(e, z); } +///////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, test_tanh_bp) { + NDArray x('c', { 2, 3, 4 }, sd::DataType::FLOAT32); + NDArray dLdz('c', { 2, 3, 4 }, sd::DataType::FLOAT32); + NDArray dLdx('c', { 2, 3, 4 }, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + dLdz.linspace(0.01, 0.01); + + NDArray e('c', { 2, 3, 4 }, { 0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732 }, sd::DataType::FLOAT32); + + sd::ops::tanh_bp op; + op.execute({ &x, &dLdz }, { &dLdx }); + ASSERT_EQ(e, dLdx); +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, test_tanh_bp2) { + + NDArray x('f', { 2, 3, 4 }, sd::DataType::FLOAT32); + NDArray dLdz('f', { 2, 3, 4 }, sd::DataType::FLOAT32); + NDArray dLdx('f', { 2, 3, 4 }, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + dLdz.linspace(0.01, 0.01); + + NDArray exp('c', { 2, 3, 4 }, { 0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732 }, sd::DataType::FLOAT32); + NDArray e('f', { 2, 3, 4 }, sd::DataType::FLOAT32); + e.assign(exp); + + sd::ops::tanh_bp op; + op.execute({ &x, &dLdz }, { &dLdx }); + ASSERT_EQ(e, dLdx); +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, test_tanh_bp3) { + + NDArray x('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); + NDArray dLdz('f', { 2,2, 3,3, 4,4 }, sd::DataType::FLOAT32); + NDArray dLdx('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); + + x.linspace(-1.5, 0.005); + dLdz.linspace(-1., 0.01); + + NDArray exp('c', { 2, 2, 3, 3, 4, 4 }, { -0.180707, -0.180525, -0.180324, -0.180103, -0.179861, -0.179599, -0.179315, -0.179009, -0.178682, -0.178333, -0.177961, -0.177566, -0.177148, -0.176706, -0.176240, -0.175750, -0.175236, -0.174696, -0.174130, -0.173539, -0.172922, -0.172278, -0.171607, -0.170909, -0.170183, -0.169429, -0.168646, -0.167834, -0.166993, -0.166123, -0.165222, -0.164290, -0.163327, -0.162334, -0.161308, -0.160250, -0.159159, -0.158035, -0.156877, -0.155686, -0.154460, -0.153199, -0.151903, -0.150571, -0.149203, -0.147798, -0.146356, -0.144876, -0.143359, -0.141803, -0.140207, -0.138573, -0.136898, -0.135183, -0.133428, -0.131630, -0.129792, -0.127910, -0.125986, -0.124019, -0.122008, -0.119953, -0.117853, -0.115708, -0.113517, -0.111279, -0.108996, -0.106665, -0.104286, -0.101859, -0.099383, -0.096859, -0.094284, -0.091660, -0.088984, -0.086258, -0.083480, -0.080649, -0.077766, -0.074830, -0.071840, -0.068796, -0.065697, -0.062543, -0.059334, -0.056068, -0.052745, -0.049365, -0.045928, -0.042432, -0.038878, -0.035264, -0.031591, -0.027858, -0.024064, -0.020209, -0.016292, -0.012313, -0.008272, -0.004168, 0.000000, 0.004232, 0.008528, 0.012889, 0.017316, 0.021808, 0.026367, 0.030992, 0.035684, 0.040444, 0.045272, 0.050169, 0.055134, 0.060168, 0.065273, 0.070447, 0.075692, 0.081007, 0.086394, 0.091853, 0.097383, 0.102986, 0.108662, 0.114411, 0.120233, 0.126129, 0.132099, 0.138144, 0.144263, 0.150457, 0.156727, 0.163072, 0.169493, 0.175990, 0.182564, 0.189214, 0.195941, 0.202745, 0.209627, 0.216585, 0.223622, 0.230736, 0.237929, 0.245200, 0.252549, 0.259976, 0.267482, 0.275066, 0.282730, 0.290472, 0.298293, 0.306193, 0.314172, 0.322230, 0.330366, 0.338582, 0.346877, 0.355250, 0.363703, 0.372234, 0.380844, 0.389532, 0.398299, 0.407144, 0.416067, 0.425068, 0.434147, 0.443303, 0.452537, 0.461848, 0.471235, 0.480699, 0.490240, 0.499856, 0.509548, 0.519314, 0.529156, 0.539072, 0.549062, 0.559126, 0.569262, 0.579471, 0.589753, 0.600106, 0.610530, 0.621024, 0.631588, 0.642222, 0.652924, 0.663694, 0.674532, 0.685436, 0.696406, 0.707441, 0.718541, 0.729704, 0.740931, 0.752219, 0.763568, 0.774978, 0.786448, 0.797976, 0.809561, 0.821203, 0.832901, 0.844654, 0.856460, 0.868319, 0.880230, 0.892191, 0.904201, 0.916260, 0.928366, 0.940518, 0.952715, 0.964955, 0.977238, 0.989561, 1.001925, 1.014327, 1.026767, 1.039242, 1.051752, 1.064295, 1.076870, 1.089475, 1.102109, 1.114771, 1.127459, 1.140171, 1.152907, 1.165664, 1.178441, 1.191237, 1.204050, 1.216878, 1.229720, 1.242573, 1.255438, 1.268311, 1.281192, 1.294078, 1.306968, 1.319860, 1.332753, 1.345644, 1.358533, 1.371417, 1.384294, 1.397163, 1.410022, 1.422870, 1.435704, 1.448522, 1.461323, 1.474105, 1.486867, 1.499606, 1.512321, 1.525009, 1.537669, 1.550299, 1.562897, 1.575462, 1.587991, 1.600483, 1.612935, 1.625347, 1.637715, 1.650040, 1.662317, 1.674545, 1.686724, 1.698850, 1.710922, 1.722939, 1.734897, 1.746797, 1.758635, 1.770409, 1.782119, 1.793762, 1.805337, 1.816842, 1.828274, 1.839633, 1.850916, 1.862121, 1.873248, 1.884294, 1.895258, 1.906137, 1.916931, 1.927637, 1.938255, 1.948782, 1.959216, 1.969557, 1.979802, 1.989950, 2.000000, 2.009950, 2.019798, 2.029543, 2.039184, 2.048719, 2.058147, 2.067466, 2.076675, 2.085773, 2.094759, 2.103630, 2.112386, 2.121026, 2.129548, 2.137952, 2.146235, 2.154397, 2.162437, 2.170354, 2.178146, 2.185813, 2.193353, 2.200766, 2.208051, 2.215207, 2.222232, 2.229127, 2.235889, 2.242520, 2.249017, 2.255379, 2.261607, 2.267699, 2.273656, 2.279475, 2.285158, 2.290702, 2.296108, 2.301376, 2.306503, 2.311491, 2.316339, 2.321046, 2.325613, 2.330038, 2.334321, 2.338464, 2.342464, 2.346322, 2.350037, 2.353610, 2.357041, 2.360329, 2.363475, 2.366478, 2.369338, 2.372056, 2.374632, 2.377065, 2.379356, 2.381505, 2.383512, 2.385378, 2.387103, 2.388686, 2.390128, 2.391431, 2.392593, 2.393615, 2.394499, 2.395244, 2.395850, 2.396319, 2.396650, 2.396845, 2.396904, 2.396826, 2.396615, 2.396268, 2.395789, 2.395176, 2.394431, 2.393554, 2.392547, 2.391410, 2.390144, 2.388749, 2.387227, 2.385578, 2.383804, 2.381904, 2.379880, 2.377734, 2.375465, 2.373075, 2.370565, 2.367936, 2.365188, 2.362324, 2.359343, 2.356247, 2.353038, 2.349715, 2.346280, 2.342735, 2.339080, 2.335316, 2.331445, 2.327468, 2.323386, 2.319200, 2.314912, 2.310522, 2.306031, 2.301442, 2.296754, 2.291970, 2.287090, 2.282116, 2.277049, 2.271890, 2.266641, 2.261302, 2.255876, 2.250362, 2.244763, 2.239080, 2.233314, 2.227467, 2.221538, 2.215531, 2.209445, 2.203284, 2.197047, 2.190736, 2.184352, 2.177897, 2.171371, 2.164777, 2.158115, 2.151386, 2.144592, 2.137735, 2.130815, 2.123833, 2.116792, 2.109692, 2.102533, 2.095320, 2.088051, 2.080727, 2.073352, 2.065925, 2.058447, 2.050921, 2.043347, 2.035727, 2.028061, 2.020351, 2.012599, 2.004804, 1.996969, 1.989094, 1.981181, 1.973232, 1.965246, 1.957225, 1.949171, 1.941084, 1.932965, 1.924816, 1.916638, 1.908432, 1.900198, 1.891938, 1.883654, 1.875345, 1.867014, 1.858661, 1.850286, 1.841892, 1.833479, 1.825048, 1.816600, 1.808136, 1.799657, 1.791165, 1.782659, 1.774141, 1.765612, 1.757073, 1.748523, 1.739967, 1.731401, 1.722829, 1.714251, 1.705668, 1.697082, 1.688491, 1.679897, 1.671302, 1.662707, 1.654110, 1.645514, 1.636920, 1.628328, 1.619738, 1.611152, 1.602570, 1.593993, 1.585422, 1.576857, 1.568299, 1.559749, 1.551207, 1.542674, 1.534151, 1.525638, 1.517136, 1.508645, 1.500167, 1.491701, 1.483248, 1.474810, 1.466385, 1.457976, 1.449581, 1.441203, 1.432841, 1.424496, 1.416169, 1.407860, 1.399569, 1.391297, 1.383045, 1.374812, 1.366600, 1.358408, 1.350237, 1.342088, 1.333961, 1.325856, 1.317774, 1.309715, 1.301679, 1.293668, 1.285680, 1.277718, 1.269780, 1.261867, 1.253980, 1.246119, 1.238283, 1.230474, 1.222692, 1.214937, 1.207210, 1.199510, 1.191837, 1.184193, 1.176577, 1.168990, 1.161430, 1.153901, 1.146401, 1.138930, 1.131489, 1.124077, 1.116696, 1.109345, 1.102024, 1.094734, 1.087475, 1.080246, 1.073049 }, sd::DataType::FLOAT32); + + NDArray e('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); + e.assign(exp); + + sd::ops::tanh_bp op; + op.execute({ &x, &dLdz }, { &dLdx }); + ASSERT_EQ(e, dLdx); +} ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST) { @@ -134,5 +186,4 @@ TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) { Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output.equalsTo(exp)); - } diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index 2f88e069e..dcbfa29b0 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -75,6 +75,9 @@ TEST_F(MklDnnTests, helpers_includer) { sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; - printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh }); + sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp; + + printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp }); + #endif } \ No newline at end of file