diff --git a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp index ad7a430f4..dbabad395 100644 --- a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp @@ -14,10 +14,11 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// xw_plus_b op. Created by GS 31.01.2018 -// -// + // + // xw_plus_b op. Created by GS 31.01.2018 + // @author Oleg Semeniv + // + // #include #if NOT_EXCLUDED(OP_xw_plus_b) @@ -29,36 +30,115 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(x->rankOf() <= 2 && y->rankOf() <= 2 && z->rankOf() <= 2, 0, "xw_plus_b: Input and Output NDArrays should have rank less or equal to 2"); - REQUIRE_TRUE(b->isVector() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input vector should have proper dimension 1x%i. " - "But %i != %i.", z->sizeAt(-1), b->lengthOf(), z->sizeAt(-1)); + if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty()) + return Status::OK(); + + const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false); + + auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); + + REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); + REQUIRE_TRUE(z->rankOf() == 2, 0, "xw_plus_b: Output array should have rank equal 2, but got instead %i!", z->rankOf()); + + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input bias vector should be 1D and have proper dimension 1x%i." 
+ " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); + // multiply x to y - MmulHelper::mmul(x, y, z, 1.0, 0.0); + MmulHelper::mmul(x, w, z, 1.0, 0.0); // adding b vector z->addiRowVector(*b); + if (bTranspose) + delete w; + return Status::OK(); } DECLARE_SHAPE_FN(xw_plus_b) { - auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), inputShape->at(1), false, false, - ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace()); + + auto weights = INPUT_VARIABLE(1); + + const int nWeightsFormat = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + + auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.getWorkspace()) : inputShape->at(1); + + auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), weightsShape, false, false, + ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace()); return SHAPELIST(CONSTANT(outputShape)); } DECLARE_TYPES(xw_plus_b) { getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ ALL_FLOATS }); + } + + + CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) { + + auto x = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdb = OUTPUT_VARIABLE(2); + + if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || dLdz->isEmpty()) + return Status::OK(); + + const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false); + + auto w = bTranspose ? 
new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); + + REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); + REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf()); + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, "xw_plus_b BP: Input bias vector should be 1D and have proper dimension 1x%i." + " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1)); + + auto dLdw = (bTranspose) ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) : OUTPUT_VARIABLE(1); + + // dLdb + dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, { 0 })); + + matmul_bp mmul_bp; + mmul_bp.execute({ x, w, dLdz }, std::vector{dLdx, dLdw}, {}, {}, {}); + + if (bTranspose) { + delete w; + delete dLdw; + } + return Status::OK(); + } + + DECLARE_SHAPE_FN(xw_plus_b_bp) { + + Nd4jLong* xShapeInfo; + Nd4jLong* wShapeInfo; + Nd4jLong* bShapeInfo; + + COPY_SHAPE(inputShape->at(0), xShapeInfo); + COPY_SHAPE(inputShape->at(1), wShapeInfo); + COPY_SHAPE(inputShape->at(2), bShapeInfo); + + return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(wShapeInfo), CONSTANT(bShapeInfo)); + } + + DECLARE_TYPES(xw_plus_b_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ ALL_FLOATS }); } } } -#endif \ No newline at end of file +#endif diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 81742fa3d..f3131c193 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -867,9 +867,12 @@ namespace sd { * - 2D matrix MxN * - 1D vector with N elements * output value - 2D matrix NxN as 
multiply of matrixes and add vector + * Int args: + * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul */ #if NOT_EXCLUDED(OP_xw_plus_b) - DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0); + DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0); + DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0); #endif /** diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index dd512a884..514a325c7 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -96,6 +96,10 @@ namespace sd { DECLARE_PLATFORM(tanh_bp, ENGINE_CPU); + DECLARE_PLATFORM(xw_plus_b, ENGINE_CPU); + + DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU); + } } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp new file mode 100644 index 000000000..01a003c2c --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp @@ -0,0 +1,426 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleg Semeniv + // + // + +#include +#include +#include +#include +#include "mkldnnUtils.h" + +using namespace dnnl; + +namespace sd { + namespace ops { + namespace platforms { + + ////////////////////////////////////////////////////////////////////// + static void xwPlusBiasMKLDNN(const NDArray* x, const NDArray* weights, const NDArray* bias, NDArray* z, const bool bShouldTransp) { + + // mkl works with following + // [M,K] x [N,K]^T + [N] = [M,N] + const auto xRank = x->rankOf(); + + // [M,K] x [K,N] = [M,N] + const int M = x->sizeAt(0); + const int K = x->sizeAt(1); // K == wK + const int N = z->sizeAt(1); + + dnnl::memory::dims xShape = dnnl::memory::dims({ M, K }); + dnnl::memory::dims wShape = dnnl::memory::dims({ N, K }); + dnnl::memory::dims zShape = dnnl::memory::dims({ M, N }); + dnnl::memory::dims bShape = dnnl::memory::dims({ N }); + + dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; + + // x type + dnnl::memory::data_type xType = dnnl::memory::data_type::f32; + if (x->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else if (x->dataType() == DataType::INT8) + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = (weights->dataType() == DataType::FLOAT32) ? 
+ dnnl::memory::data_type::f32 : dnnl::memory::data_type::s8; + + // bias type need add description for bias + dnnl::memory::data_type bType = dnnl::memory::data_type::f32; + if (bias->dataType() == DataType::INT32) + bType = dnnl::memory::data_type::s32; + else if (bias->dataType() == DataType::UINT8) + bType = dnnl::memory::data_type::u8; + else if (bias->dataType() == DataType::INT8) + bType = dnnl::memory::data_type::s8; + + // z type + dnnl::memory::data_type zType = dnnl::memory::data_type::f32; + if (z->dataType() == DataType::INT32) + zType = dnnl::memory::data_type::s32; + else if (z->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (z->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + + // memory descriptors for arrays + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); + mkldnnUtils::setBlockStrides(x, x_user_md); + + // weights + dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, format); + if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { + + weights_user_md.data.format_kind = dnnl_blocked; // overrides format + if (bShouldTransp) { + weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1); + weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0); + } + else { + weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0); + weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1); + } + } + // bias + dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x); + dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x); + mkldnnUtils::setBlockStrides(bias, 
bias_user_md); + + // z + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); + mkldnnUtils::setBlockStrides(z, z_user_md); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::inner_product_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, z_mkl_md); + + dnnl::inner_product_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map<int, dnnl::memory> args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + + // bias + auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, bias->getBuffer()); + args[DNNL_ARG_BIAS] = bias_mkl_mem; + + // z + auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = zReorder ? 
dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::inner_product_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + } + + ////////////////////////////////////////////////////////////////////// + static void xwPlusBiasBp(const NDArray* x, const NDArray* weights, const NDArray* bias, const NDArray* dLdz, + NDArray* dLdx, NDArray* dLdw, NDArray* dLdb, const bool bShouldTransp) { + + // mkl works with following + // [M,K] x [N,K]^T + [N] = [M,N] + const auto xRank = x->rankOf(); + + // [M,K] x [K,N] = [M,N] + const int M = x->sizeAt(0); + const int K = x->sizeAt(1); // K == wK + const int N = dLdz->sizeAt(1); + // input dims + dnnl::memory::dims xShape = dnnl::memory::dims({ M, K }); + dnnl::memory::dims wShape = dnnl::memory::dims({ N, K }); + dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N }); + + dnnl::memory::dims bShape = dnnl::memory::dims({ N }); + // output dims + dnnl::memory::dims dLdxShape = xShape; + dnnl::memory::dims dLdwShape = wShape; + + dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; + dnnl::memory::data_type dataType = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format); + mkldnnUtils::setBlockStrides(x, x_user_md); + + // weights + dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, format); + if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { + + weights_user_md.data.format_kind = dnnl_blocked; // overrides format + if (bShouldTransp) { + 
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1); + weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0); + } + else { + weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0); + weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1); + } + } + // bias + dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + mkldnnUtils::setBlockStrides(bias, bias_user_md); + + // dLdz + dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, format); + mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + + // dLdw + dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format); + dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, format); + if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) { + + dLdw_user_md.data.format_kind = dnnl_blocked; // overrides format + if (bShouldTransp) { + dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1); + dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0); + } + else { + dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0); + dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1); + } + } + + // dLdb + dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md); + + // dLdx + dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, format); + 
mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + // forward + // operation primitive description + dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md); + dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backprop + // dLdw + auto op_bpdw_desc = inner_product_backward_weights::desc(x_mkl_md, dLdw_mkl_md, dLdb_mkl_md, dLdz_mkl_md); + auto op_bpdw_prim_desc = inner_product_backward_weights::primitive_desc(op_bpdw_desc, engine, op_ff_prim_desc); + + // backprop + // dLdx + auto op_bpdx_desc = inner_product_backward_data::desc(dLdx_mkl_md, weights_mkl_md, dLdz_mkl_md); + auto op_bpdx_prim_desc = inner_product_backward_data::primitive_desc(op_bpdx_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map<int, dnnl::memory> argsDw, argsDx; + + dnnl::stream stream(engine); + + // dLdz dw + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]); + + // dLdz - dx + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]); + + // input x for dw + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]); + + // weights - dx + mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]); + + // dLdw + auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->getBuffer()); + const bool dLdwReorder = op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc(); + auto dLdw_mkl_mem = dLdwReorder ? 
dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) : dLdw_user_mem; + argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem; + + // dLdx + auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); + const bool dLdxReorder = op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc(); + auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) : dLdx_user_mem; + argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; + + // dLdb + auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->getBuffer()); + const bool dLdbReorder = op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc(); + auto dLdb_mkl_mem = dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) : dLdb_user_mem; + argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem; + + // run calculations dw + dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw); + // run calculations dx + dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx); + + // reorder outputs if necessary + if (dLdxReorder) + dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); + + if (dLdwReorder) + dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem).execute(stream, dLdw_mkl_mem, dLdw_user_mem); + + if (dLdbReorder) + dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem).execute(stream, dLdb_mkl_mem, dLdb_user_mem); + + stream.wait(); + } + + PLATFORM_IMPL(xw_plus_b, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); + + if (x->isEmpty() || w->isEmpty() || b->isEmpty()) + return Status::OK(); + + const int xRank = x->rankOf(); + const int wRank = w->rankOf(); + const int zRank = z->rankOf(); + + const bool bShouldTransp = block.getIArguments()->size() > 0 ? 
(1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + + REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank); + REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank); + REQUIRE_TRUE(zRank == 2, 0, "xw_plus_b MKL: Output array should have rank equal 2, but got instead %i!", zRank); + + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b MKL: Input bias vector should be 1D and have proper dimension 1x%i." + " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); + + // mkldnnInerPorductss + xwPlusBiasMKLDNN(x, w, b, z, bShouldTransp); + + return Status::OK(); + } + + PLATFORM_CHECK(xw_plus_b, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType wType = w->dataType(); + const DataType bType = b->dataType(); + const DataType zType = z->dataType(); + + /* + Source Weights Destination Bias + f32 f32 f32 f32 + u8, s8 s8 u8, s8, s32, f32 u8, s8, s32, f32 + */ + return block.isUseMKLDNN() && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + ( // x + (xType == DataType::UINT8 || xType == DataType::INT8) && + // w + (wType == DataType::UINT8 || wType == DataType::INT8) && + // b + (bType == DataType::UINT8 || bType == DataType::INT8 || bType == DataType::INT32 || bType == DataType::FLOAT32) && + // z + (zType == DataType::UINT8 || zType == DataType::INT8 || zType == DataType::INT32 || zType == DataType::FLOAT32) + )); + } + + PLATFORM_IMPL(xw_plus_b_bp, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = 
OUTPUT_VARIABLE(0); + auto dLdw = OUTPUT_VARIABLE(1); + auto dLdb = OUTPUT_VARIABLE(2); + + if (x->isEmpty() || w->isEmpty() || b->isEmpty() || dLdz->isEmpty()) + return Status::OK(); + + const int xRank = x->rankOf(); + const int wRank = w->rankOf(); + const int dLdzRank = dLdz->rankOf(); + + const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + + REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); + REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP MKL: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf()); + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(1), 0, "xw_plus_b BP MKL: Input bias vector should be 1D and have proper dimension 1x%i." 
+ " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(1)); + + xwPlusBiasBp(x, w, b, dLdz, dLdx, dLdw, dLdb, bShouldTransp); + + return Status::OK(); + } + + PLATFORM_CHECK(xw_plus_b_bp, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdw = OUTPUT_VARIABLE(1); + auto dLdb = OUTPUT_VARIABLE(2); + + const DataType xType = x->dataType(); + const DataType wType = w->dataType(); + const DataType bType = b->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); + const DataType dLdwType = dLdw->dataType(); + const DataType dLdbType = dLdb->dataType(); + + /* + Source Weights Destination Bias + f32 f32 f32 f32 + */ + return block.isUseMKLDNN() && + (xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && + dLdbType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32 && + dLdwType == DataType::FLOAT32); + } + + } + } +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp index b1cafa073..1f36a8f2c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -15,9 +15,9 @@ ******************************************************************************/ -// -// @author raver119@gmail.com -// + // + // @author raver119@gmail.com + // #include "testlayers.h" #include @@ -45,19 +45,19 @@ TEST_F(DeclarableOpsTests18, test_bitcast_1) { auto e = NDArrayFactory::create(4597464930322771456L); sd::ops::bitcast op; - auto status = op.execute({&x}, {&z}, {}, {(Nd4jLong) sd::DataType::INT64}, {}); + auto status = op.execute({ &x }, { &z }, {}, { (Nd4jLong)sd::DataType::INT64 }, {}); ASSERT_EQ(Status::OK(), status); ASSERT_EQ(e, z); } 
TEST_F(DeclarableOpsTests18, test_tanh_1) { - auto x = NDArrayFactory::create('c', {8}, {0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f}); + auto x = NDArrayFactory::create('c', { 8 }, { 0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f }); auto z = x.ulike(); - auto e = NDArrayFactory::create('c', {8}, {0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, -1.f}); + auto e = NDArrayFactory::create('c', { 8 }, { 0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, -1.f }); sd::ops::tanh op; - op.execute({&x}, {&z}); + op.execute({ &x }, { &z }); ASSERT_EQ(e, z); } @@ -187,6 +187,197 @@ TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) { ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output.equalsTo(exp)); } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_1) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + NDArray dLdz('c', { 2, 2 }, DataType::FLOAT32); + dLdz.linspace(1); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 17.f, 14.f, 10.f, 45.f, 32.f, 26.f }); + auto edLdw = NDArrayFactory::create('c', { 3,2 }, { 43.f, 58.f, 26.f, 42.f, 21.f, 30.f }); + auto edLdb = NDArrayFactory::create('c', { 2 }, { 4.f, 6.f }); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); +} 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_2) { + + auto x = NDArrayFactory::create('c', { 6,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('c', { 3,4 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f, 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create('c', { 4 }, { 100.f, 200.f, 100.f, 200.f }); + + NDArray dLdz('c', { 6, 4 }, DataType::FLOAT32); + dLdz.linspace(.1, .5); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 6,3 }, { 15.3f, 18.700001f, 13.2f, 61.299995f, 62.699997f, 47.200001f, 107.299995f, 106.699997f, 81.199997f, 153.299988f, 150.699997f, 115.199997f, 199.300018f, 194.700012f, 149.199997f, 245.300018f, 238.700012f, 183.199997f }); + auto edLdw = NDArrayFactory::create('c', { 3,4 }, { 268.5f, 291.f, 313.5f, 336.f, 226.800003f, 250.800003f, 274.799988f, 298.799988f, 146.699997f, 160.199997f, 173.700012f, 187.200012f }); + auto edLdb = NDArrayFactory::create('c', { 4 }, { 30.6f, 33.599998f, 36.599998f, 39.599998f }); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_3) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto w = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f, 300.f }); + + auto dLdz = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); + + 
sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 3937.f, 3096.f }); + auto edLdw = NDArrayFactory::create('c', { 2,3 }, { 166.f, 269.f, 326.f, 1826.f, 2959.f, 3586.f }); + auto edLdb = NDArrayFactory::create('c', { 3 }, { 166.f, 269.f, 326.f }); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_4) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto w = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); + + auto dLdz = NDArrayFactory::create('c', { 1,1 }, { 244.f }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 2684.f, 732.f }); + auto edLdw = NDArrayFactory::create('c', { 2,1 }, { 244.f, 2684.f }); + auto edLdb = NDArrayFactory::create('c', { 1 }, { 244.f }); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_5) { + + auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = 
NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto dLdz = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdxC = NDArrayFactory::create('c', { 2,3 }, { 2705.f, 1818.f, 1026.f, 4912.f, 2967.f, 1850.f }); + auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 3297.f, 4094.f, 4438.f, 5613.f, 2422.f, 3271.f }); + auto edLdbC = NDArrayFactory::create('c', { 2 }, { 427.f, 584.f }); + + auto edLdx = NDArrayFactory::create('f', { 2,3 }); + auto edLdw = NDArrayFactory::create('f', { 3,2 }); + auto edLdb = NDArrayFactory::create('f', { 2 }); + + edLdx.assign(edLdxC); + edLdw.assign(edLdwC); + edLdb.assign(edLdbC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_6) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto dLdz = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + + // mkl-format + w.permutei({ 1,0 }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 2695.f, 2012.f, 1566.f, 4247.f, 2635.f, 2418.f 
}); + auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); + auto edLdb = NDArrayFactory::create('c', { 2 }, { 483.f, 543.f }); + auto edLdw = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); + edLdw.permutei({ 1,0 }); + edLdw.assign(edLdwC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); +} ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterSgd1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 8958f9023..6ac9d34cd 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -2432,18 +2432,36 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_3) { } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_1) { - auto x = NDArrayFactory::create('c', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); - auto y = NDArrayFactory::create('c', {3,2}, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); - auto b = NDArrayFactory::create({100.f, 200.f}); + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); - auto exp = NDArrayFactory::create('c', {2,2}, {173.f, 264.f, 310.f, 279.f}); + auto exp = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); sd::ops::xw_plus_b op; - auto result = op.evaluate({&x, &y, &b}, {}, {}); + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_2) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto y = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f, 300.f }); + + auto exp = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); @@ -2452,9 +2470,107 @@ TEST_F(DeclarableOpsTests5, XWPlusB_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_3) { + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); + + auto exp = NDArrayFactory::create('c', { 1,1 }, { 244.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_4) { + + auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto exp = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + 
ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_5) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + + y = y.transpose(); + + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto exp = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_6) { + + auto x = NDArrayFactory::create('c', { 3, 2 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + + auto b = NDArrayFactory::create('c', { 1 }, { 100.f }); + + auto exp = NDArrayFactory::create('c', { 3, 1 }, { 144.f, 175.f, 173.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_7) { + + auto x = NDArrayFactory::create('c', { 3, 4 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 4, 5 }, { 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 3.f, 11.f, 3.f, 11.f }); + + auto b = NDArrayFactory::create('c', { 5 }, { 100.f, 200.f, 300.f, 400.f, 500.f }); + + auto exp = NDArrayFactory::create('c', { 3, 5 }, { 219.f, 375.f, 531.f, 575.f, 731.f, 217.f, 317.f, 505.f, 
517.f, 705.f, 248.f, 396.f, 496.f, 596.f, 696.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, StopGradient_1) { diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index dcbfa29b0..bb3934994 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -76,8 +76,13 @@ TEST_F(MklDnnTests, helpers_includer) { sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp; + + sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b; - printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp }); + sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU xw_plus_b_bp; + + printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, &xw_plus_b_bp }); + #endif -} \ No newline at end of file +}