xw_plus_b mkldnn implementation (#247)

* libnd4j first step of mkldnn support for xw_plus_b and a test for the Aurora crash in imageHelper

* libnd4j sync folders with master

* libnd4j merged master; raw implementation of xw_plus_b on mkldnn, clean-up; still needs testing and checks for corresponding input shapes

* libnd4j corrections and checks added to xw_plus_b mkl

* libnd4j corrected dataType description based on the mkl operation description, needs more investigation

* libnd4j fixed xw_plus_b mkl implementation, needs testing

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j two unit tests added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed a bug in the input dimension checks

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j one more test added to cover different order handling

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j added an optional int arg to define the weights format: if arg == 1, weights are in mkldnn format (no transpose needed in the mkldnn implementation), otherwise in mmul format; corrected checks, added a unit test

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j merge master

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some improvements to avoid NDArray transpose in xw_plus_b operation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed issues related to the weights rank, added support for one case based on TF behaviour (for mkldnn, cpu, cuda), test case added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j added proper handling of empty inputs (all implementations)

* libnd4j fixed compilation error

* libnd4j several more corrections after conflict resolution and fixed typos

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j removed unsupported data types

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j merge master and fixed issues

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j added backpropagation implementation for xw_plus_b, fixed an issue with the mkl weights data format, avoided a data copy in transpose mode, test cases added, manually tested with gradCheck

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j one minor fix of a duplicated operation declaration

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j code clean up

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j minor test fixes

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed build problem, integrated helpers changes

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
Oleh 2020-03-31 13:03:10 +03:00 committed by GitHub
parent 29e61579c1
commit 1d004b542a
7 changed files with 856 additions and 31 deletions


@@ -14,10 +14,11 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// xw_plus_b op. Created by GS <george@skymind.io> 31.01.2018
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_xw_plus_b)
@@ -29,34 +30,113 @@
namespace sd {
namespace ops {
CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto b = INPUT_VARIABLE(2);
auto z = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(x->rankOf() <= 2 && y->rankOf() <= 2 && z->rankOf() <= 2, 0, "xw_plus_b: Input and Output NDArrays should have rank less or equal to 2");
REQUIRE_TRUE(b->isVector() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input vector should have proper dimension 1x%i. "
"But %i != %i.", z->sizeAt(-1), b->lengthOf(), z->sizeAt(-1));
if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty())
return Status::OK();
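// iArg(0) == 1 means the weights come in mkldnn format [N,K]; transpose them here so MmulHelper gets the usual [K,N]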
const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false);
auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1);
REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
REQUIRE_TRUE(z->rankOf() == 2, 0, "xw_plus_b: Output array should have rank equal 2, but got instead %i!", z->rankOf());
REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input bias vector should be 1D and have proper dimension 1x%i."
" But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1));
// multiply x to y
MmulHelper::mmul(x, y, z, 1.0, 0.0);
MmulHelper::mmul(x, w, z, 1.0, 0.0);
// adding b vector
z->addiRowVector(*b);
if (bTranspose)
delete w;
return Status::OK();
}
DECLARE_SHAPE_FN(xw_plus_b) {
auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), inputShape->at(1), false, false,
ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace());
auto weights = INPUT_VARIABLE(1);
const int nWeightsFormat = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.getWorkspace()) : inputShape->at(1);
auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), weightsShape, false, false,
ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace());
return SHAPELIST(CONSTANT(outputShape));
}
DECLARE_TYPES(xw_plus_b) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ ALL_FLOATS });
}
CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(2);
auto dLdz = INPUT_VARIABLE(3);
auto dLdx = OUTPUT_VARIABLE(0);
auto dLdb = OUTPUT_VARIABLE(2);
if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || dLdz->isEmpty())
return Status::OK();
const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false);
auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1);
REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf());
REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, "xw_plus_b BP: Input bias vector should be 1D and have proper dimension 1x%i."
" But got rank %i, and got length %i instead %i.", dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1));
auto dLdw = (bTranspose) ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) : OUTPUT_VARIABLE(1);
// dLdb
dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, { 0 }));
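// dLdx and dLdw come from the matmul backprop: dLdx = dLdz * w^T, dLdw = x^T * dLdz (w is in mmul format [K,N] at this point)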
matmul_bp mmul_bp;
mmul_bp.execute({ x, w, dLdz }, std::vector<NDArray*>{dLdx, dLdw}, {}, {}, {});
if (bTranspose) {
delete w;
delete dLdw;
}
return Status::OK();
}
DECLARE_SHAPE_FN(xw_plus_b_bp) {
Nd4jLong* xShapeInfo;
Nd4jLong* wShapeInfo;
Nd4jLong* bShapeInfo;
COPY_SHAPE(inputShape->at(0), xShapeInfo);
COPY_SHAPE(inputShape->at(1), wShapeInfo);
COPY_SHAPE(inputShape->at(2), bShapeInfo);
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(wShapeInfo), CONSTANT(bShapeInfo));
}
DECLARE_TYPES(xw_plus_b_bp) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ ALL_FLOATS });
}
}
}


@@ -867,9 +867,12 @@ namespace sd {
* - 2D matrix MxN
* - 1D vector with N elements
* output value - 2D matrix, the matrix product of the inputs with the bias vector added
* Int args:
* 0 - optional weights format flag: 1 - weights are already in mkldnn format (no transpose needed), otherwise mmul format
*/
#if NOT_EXCLUDED(OP_xw_plus_b)
DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0);
DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0);
#endif
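// Illustration only (not part of this commit): a minimal usage sketch mirroring the unit tests added in this
// change; shapes and values are arbitrary, and the helper name is hypothetical.
#include <ops/declarable/CustomOperations.h>   // sd::ops::xw_plus_b, NDArrayFactory
static void xwPlusBUsageSketch() {
    using namespace sd;
    auto x = NDArrayFactory::create<float>('c', { 2, 3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });  // [M,K]
    auto w = NDArrayFactory::create<float>('c', { 3, 2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });   // [K,N], mmul format
    auto b = NDArrayFactory::create<float>({ 100.f, 200.f });                                   // [N]
    sd::ops::xw_plus_b op;
    // default (no int arg): weights are expected in mmul format [K,N]
    auto result = op.evaluate({ &x, &w, &b });
    auto z = result.at(0);                                                                      // [M,N] = x*w + b
    // weights already stored as [N,K] (mkldnn format): pass int arg 0 == 1 to skip the internal transpose
    auto wT = w.transpose();
    auto resultT = op.evaluate({ &x, &wT, &b }, {}, { 1 });
}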
/**


@@ -96,6 +96,10 @@ namespace sd {
DECLARE_PLATFORM(tanh_bp, ENGINE_CPU);
DECLARE_PLATFORM(xw_plus_b, ENGINE_CPU);
DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU);
}
}


@@ -0,0 +1,426 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
using namespace dnnl;
namespace sd {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////
static void xwPlusBiasMKLDNN(const NDArray* x, const NDArray* weights, const NDArray* bias, NDArray* z, const bool bShouldTransp) {
// mkl inner product works with the following layout:
// [M,K] x [N,K]^T + [N] = [M,N]
const auto xRank = x->rankOf();
// [M,K] x [K,N] = [M,N]
const int M = x->sizeAt(0);
const int K = x->sizeAt(1); // K == wK
const int N = z->sizeAt(1);
dnnl::memory::dims xShape = dnnl::memory::dims({ M, K });
dnnl::memory::dims wShape = dnnl::memory::dims({ N, K });
dnnl::memory::dims zShape = dnnl::memory::dims({ M, N });
dnnl::memory::dims bShape = dnnl::memory::dims({ N });
dnnl::memory::format_tag format = dnnl::memory::format_tag::ab;
// x type
dnnl::memory::data_type xType = dnnl::memory::data_type::f32;
if (x->dataType() == DataType::UINT8)
xType = dnnl::memory::data_type::u8;
else if (x->dataType() == DataType::INT8)
xType = dnnl::memory::data_type::s8;
// weights type
dnnl::memory::data_type wType = (weights->dataType() == DataType::FLOAT32) ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::s8;
// bias type - need to add a proper description for bias
dnnl::memory::data_type bType = dnnl::memory::data_type::f32;
if (bias->dataType() == DataType::INT32)
bType = dnnl::memory::data_type::s32;
else if (bias->dataType() == DataType::UINT8)
bType = dnnl::memory::data_type::u8;
else if (bias->dataType() == DataType::INT8)
bType = dnnl::memory::data_type::s8;
// z type
dnnl::memory::data_type zType = dnnl::memory::data_type::f32;
if (z->dataType() == DataType::INT32)
zType = dnnl::memory::data_type::s32;
else if (z->dataType() == DataType::UINT8)
zType = dnnl::memory::data_type::u8;
else if (z->dataType() == DataType::INT8)
zType = dnnl::memory::data_type::s8;
// memory descriptors for arrays
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
mkldnnUtils::setBlockStrides(x, x_user_md);
// weights
dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, format);
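// note: when bShouldTransp is set the weights buffer is in mmul format [K,N]; instead of materializing a
// transpose, it is described to oneDNN as [N,K] by swapping the two strides below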
if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
weights_user_md.data.format_kind = dnnl_blocked; // overrides format
if (bShouldTransp) {
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
}
else {
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
}
}
// bias
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
mkldnnUtils::setBlockStrides(bias, bias_user_md);
// z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
mkldnnUtils::setBlockStrides(z, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// operation primitive description
dnnl::inner_product_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, z_mkl_md);
dnnl::inner_product_forward::primitive_desc op_prim_desc(op_desc, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, dnnl::memory> args;
dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias
auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, bias->getBuffer());
args[DNNL_ARG_BIAS] = bias_mkl_mem;
// z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
// run calculations
dnnl::inner_product_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait();
}
//////////////////////////////////////////////////////////////////////
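// Backprop helper: for z = x*w + b (w in mmul format [K,N]) the gradients are the standard matmul ones:
//   dLdx = dLdz * w^T,  dLdw = x^T * dLdz,  dLdb = sum of dLdz over rows (the batch dimension).
// They are computed below via oneDNN inner_product_backward_weights / inner_product_backward_data,
// with the weights described in mkl layout [N,K] (strides swapped when bShouldTransp is set).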
static void xwPlusBiasBp(const NDArray* x, const NDArray* weights, const NDArray* bias, const NDArray* dLdz,
NDArray* dLdx, NDArray* dLdw, NDArray* dLdb, const bool bShouldTransp) {
// mkl inner product works with the following layout:
// [M,K] x [N,K]^T + [N] = [M,N]
const auto xRank = x->rankOf();
// [M,K] x [K,N] = [M,N]
const int M = x->sizeAt(0);
const int K = x->sizeAt(1); // K == wK
const int N = dLdz->sizeAt(1);
// input dims
dnnl::memory::dims xShape = dnnl::memory::dims({ M, K });
dnnl::memory::dims wShape = dnnl::memory::dims({ N, K });
dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N });
dnnl::memory::dims bShape = dnnl::memory::dims({ N });
// output dims
dnnl::memory::dims dLdxShape = xShape;
dnnl::memory::dims dLdwShape = wShape;
dnnl::memory::format_tag format = dnnl::memory::format_tag::ab;
dnnl::memory::data_type dataType = dnnl::memory::data_type::f32;
// memory descriptors for arrays
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format);
mkldnnUtils::setBlockStrides(x, x_user_md);
// weights
dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, format);
if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
weights_user_md.data.format_kind = dnnl_blocked; // overrides format
if (bShouldTransp) {
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
}
else {
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
}
}
// bias
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
mkldnnUtils::setBlockStrides(bias, bias_user_md);
// dLdz
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, format);
mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
// dLdw
dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format);
dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, format);
if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) {
dLdw_user_md.data.format_kind = dnnl_blocked; // overrides format
if (bShouldTransp) {
dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1);
dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0);
}
else {
dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0);
dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1);
}
}
// dLdb
dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md);
// dLdx
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, format);
mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward
// operation primitive description
dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md);
dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backprop
// dLdw
auto op_bpdw_desc = inner_product_backward_weights::desc(x_mkl_md, dLdw_mkl_md, dLdb_mkl_md, dLdz_mkl_md);
auto op_bpdw_prim_desc = inner_product_backward_weights::primitive_desc(op_bpdw_desc, engine, op_ff_prim_desc);
// backprop
// dLdx
auto op_bpdx_desc = inner_product_backward_data::desc(dLdx_mkl_md, weights_mkl_md, dLdz_mkl_md);
auto op_bpdx_prim_desc = inner_product_backward_data::primitive_desc(op_bpdx_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, dnnl::memory> argsDw, argsDx;
dnnl::stream stream(engine);
// dLdz dw
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]);
// dLdz - dx
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]);
// input x for dw
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]);
// weights - dx
mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]);
// dLdw
auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->getBuffer());
const bool dLdwReorder = op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc();
auto dLdw_mkl_mem = dLdwReorder ? dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) : dLdw_user_mem;
argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem;
// dLdx
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());
const bool dLdxReorder = op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc();
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) : dLdx_user_mem;
argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
// dLdb
auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->getBuffer());
const bool dLdbReorder = op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc();
auto dLdb_mkl_mem = dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) : dLdb_user_mem;
argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem;
// run calculations dw
dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw);
// run calculations dx
dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx);
// reorder outputs if necessary
if (dLdxReorder)
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
if (dLdwReorder)
dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem).execute(stream, dLdw_mkl_mem, dLdw_user_mem);
if (dLdbReorder)
dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem).execute(stream, dLdb_mkl_mem, dLdb_user_mem);
stream.wait();
}
PLATFORM_IMPL(xw_plus_b, ENGINE_CPU) {
auto x = INPUT_VARIABLE(0);
auto w = INPUT_VARIABLE(1);
auto b = INPUT_VARIABLE(2);
auto z = OUTPUT_VARIABLE(0);
if (x->isEmpty() || w->isEmpty() || b->isEmpty())
return Status::OK();
const int xRank = x->rankOf();
const int wRank = w->rankOf();
const int zRank = z->rankOf();
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank);
REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank);
REQUIRE_TRUE(zRank == 2, 0, "xw_plus_b MKL: Output array should have rank equal 2, but got instead %i!", zRank);
REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b MKL: Input bias vector should be 1D and have proper dimension 1x%i."
" But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1));
// mkldnn inner product
xwPlusBiasMKLDNN(x, w, b, z, bShouldTransp);
return Status::OK();
}
PLATFORM_CHECK(xw_plus_b, ENGINE_CPU) {
auto x = INPUT_VARIABLE(0);
auto w = INPUT_VARIABLE(1);
auto b = INPUT_VARIABLE(2);
auto z = OUTPUT_VARIABLE(0);
const DataType xType = x->dataType();
const DataType wType = w->dataType();
const DataType bType = b->dataType();
const DataType zType = z->dataType();
/*
Source Weights Destination Bias
f32 f32 f32 f32
u8, s8 s8 u8, s8, s32, f32 u8, s8, s32, f32
*/
return block.isUseMKLDNN() &&
((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && bType == DataType::FLOAT32 && zType == DataType::FLOAT32) ||
( // x
(xType == DataType::UINT8 || xType == DataType::INT8) &&
// w
(wType == DataType::UINT8 || wType == DataType::INT8) &&
// b
(bType == DataType::UINT8 || bType == DataType::INT8 || bType == DataType::INT32 || bType == DataType::FLOAT32) &&
// z
(zType == DataType::UINT8 || zType == DataType::INT8 || zType == DataType::INT32 || zType == DataType::FLOAT32)
));
}
PLATFORM_IMPL(xw_plus_b_bp, ENGINE_CPU) {
auto x = INPUT_VARIABLE(0);
auto w = INPUT_VARIABLE(1);
auto b = INPUT_VARIABLE(2);
auto dLdz = INPUT_VARIABLE(3);
auto dLdx = OUTPUT_VARIABLE(0);
auto dLdw = OUTPUT_VARIABLE(1);
auto dLdb = OUTPUT_VARIABLE(2);
if (x->isEmpty() || w->isEmpty() || b->isEmpty() || dLdz->isEmpty())
return Status::OK();
const int xRank = x->rankOf();
const int wRank = w->rankOf();
const int dLdzRank = dLdz->rankOf();
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP MKL: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf());
REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(1), 0, "xw_plus_b BP MKL: Input bias vector should be 1D and have proper dimension 1x%i."
" But got rank %i, and got length %i instead %i.", dLdz->sizeAt(1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(1));
xwPlusBiasBp(x, w, b, dLdz, dLdx, dLdw, dLdb, bShouldTransp);
return Status::OK();
}
PLATFORM_CHECK(xw_plus_b_bp, ENGINE_CPU) {
auto x = INPUT_VARIABLE(0);
auto w = INPUT_VARIABLE(1);
auto b = INPUT_VARIABLE(2);
auto dLdz = INPUT_VARIABLE(3);
auto dLdx = OUTPUT_VARIABLE(0);
auto dLdw = OUTPUT_VARIABLE(1);
auto dLdb = OUTPUT_VARIABLE(2);
const DataType xType = x->dataType();
const DataType wType = w->dataType();
const DataType bType = b->dataType();
const DataType dLdzType = dLdz->dataType();
const DataType dLdxType = dLdx->dataType();
const DataType dLdwType = dLdw->dataType();
const DataType dLdbType = dLdb->dataType();
/*
Source Weights Destination Bias
f32 f32 f32 f32
*/
return block.isUseMKLDNN() &&
(xType == DataType::FLOAT32 && wType == DataType::FLOAT32 &&
bType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 &&
dLdbType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32 &&
dLdwType == DataType::FLOAT32);
}
}
}
}


@@ -15,9 +15,9 @@
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
@@ -45,19 +45,19 @@ TEST_F(DeclarableOpsTests18, test_bitcast_1) {
auto e = NDArrayFactory::create<Nd4jLong>(4597464930322771456L);
sd::ops::bitcast op;
auto status = op.execute({ &x }, { &z }, {}, { (Nd4jLong)sd::DataType::INT64 }, {});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests18, test_tanh_1) {
auto x = NDArrayFactory::create<float>('c', { 8 }, { 0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f });
auto z = x.ulike();
auto e = NDArrayFactory::create<float>('c', { 8 }, { 0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, -1.f });
sd::ops::tanh op;
op.execute({ &x }, { &z });
ASSERT_EQ(e, z);
}
@@ -187,6 +187,197 @@ TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) {
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output.equalsTo(exp));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_1) {
auto x = NDArrayFactory::create<float>('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
auto w = NDArrayFactory::create<float>('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
auto b = NDArrayFactory::create<float>({ 100.f, 200.f });
NDArray dLdz('c', { 2, 2 }, DataType::FLOAT32);
dLdz.linspace(1);
sd::ops::xw_plus_b_bp op;
auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto dLdx = result.at(0);
auto dLdw = result.at(1);
auto dLdb = result.at(2);
auto edLdx = NDArrayFactory::create<float>('c', { 2,3 }, { 17.f, 14.f, 10.f, 45.f, 32.f, 26.f });
auto edLdw = NDArrayFactory::create<float>('c', { 3,2 }, { 43.f, 58.f, 26.f, 42.f, 21.f, 30.f });
auto edLdb = NDArrayFactory::create<float>('c', { 2 }, { 4.f, 6.f });
ASSERT_TRUE(edLdx.isSameShape(dLdx));
ASSERT_TRUE(edLdw.isSameShape(dLdw));
ASSERT_TRUE(edLdb.isSameShape(dLdb));
ASSERT_TRUE(edLdx.equalsTo(dLdx));
ASSERT_TRUE(edLdw.equalsTo(dLdw));
ASSERT_TRUE(edLdb.equalsTo(dLdb));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_2) {
auto x = NDArrayFactory::create<float>('c', { 6,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
auto w = NDArrayFactory::create<float>('c', { 3,4 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f, 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
auto b = NDArrayFactory::create<float>('c', { 4 }, { 100.f, 200.f, 100.f, 200.f });
NDArray dLdz('c', { 6, 4 }, DataType::FLOAT32);
dLdz.linspace(.1, .5);
sd::ops::xw_plus_b_bp op;
auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto dLdx = result.at(0);
auto dLdw = result.at(1);
auto dLdb = result.at(2);
auto edLdx = NDArrayFactory::create<float>('c', { 6,3 }, { 15.3f, 18.700001f, 13.2f, 61.299995f, 62.699997f, 47.200001f, 107.299995f, 106.699997f, 81.199997f, 153.299988f, 150.699997f, 115.199997f, 199.300018f, 194.700012f, 149.199997f, 245.300018f, 238.700012f, 183.199997f });
auto edLdw = NDArrayFactory::create<float>('c', { 3,4 }, { 268.5f, 291.f, 313.5f, 336.f, 226.800003f, 250.800003f, 274.799988f, 298.799988f, 146.699997f, 160.199997f, 173.700012f, 187.200012f });
auto edLdb = NDArrayFactory::create<float>('c', { 4 }, { 30.6f, 33.599998f, 36.599998f, 39.599998f });
ASSERT_TRUE(edLdx.isSameShape(dLdx));
ASSERT_TRUE(edLdw.isSameShape(dLdw));
ASSERT_TRUE(edLdb.isSameShape(dLdb));
ASSERT_TRUE(edLdx.equalsTo(dLdx));
ASSERT_TRUE(edLdw.equalsTo(dLdw));
ASSERT_TRUE(edLdb.equalsTo(dLdb));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_3) {
auto x = NDArrayFactory::create<float>('c', { 1, 2 }, { 1.f, 11.f });
auto w = NDArrayFactory::create<float>('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
auto b = NDArrayFactory::create<float>({ 100.f, 200.f, 300.f });
auto dLdz = NDArrayFactory::create<float>('c', { 1, 3 }, { 166.f, 269.f, 326.f });
sd::ops::xw_plus_b_bp op;
auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto dLdx = result.at(0);
auto dLdw = result.at(1);
auto dLdb = result.at(2);
auto edLdx = NDArrayFactory::create<float>('c', { 1,2 }, { 3937.f, 3096.f });
auto edLdw = NDArrayFactory::create<float>('c', { 2,3 }, { 166.f, 269.f, 326.f, 1826.f, 2959.f, 3586.f });
auto edLdb = NDArrayFactory::create<float>('c', { 3 }, { 166.f, 269.f, 326.f });
ASSERT_TRUE(edLdx.isSameShape(dLdx));
ASSERT_TRUE(edLdw.isSameShape(dLdw));
ASSERT_TRUE(edLdb.isSameShape(dLdb));
ASSERT_TRUE(edLdx.equalsTo(dLdx));
ASSERT_TRUE(edLdw.equalsTo(dLdw));
ASSERT_TRUE(edLdb.equalsTo(dLdb));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_4) {
auto x = NDArrayFactory::create<float>('c', { 1, 2 }, { 1.f, 11.f });
auto w = NDArrayFactory::create<float>('c', { 2, 1 }, { 11.f, 3.f });
auto b = NDArrayFactory::create<float>('c', { 1 }, { 200.f });
auto dLdz = NDArrayFactory::create<float>('c', { 1,1 }, { 244.f });
sd::ops::xw_plus_b_bp op;
auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto dLdx = result.at(0);
auto dLdw = result.at(1);
auto dLdb = result.at(2);
auto edLdx = NDArrayFactory::create<float>('c', { 1,2 }, { 2684.f, 732.f });
auto edLdw = NDArrayFactory::create<float>('c', { 2,1 }, { 244.f, 2684.f });
auto edLdb = NDArrayFactory::create<float>('c', { 1 }, { 244.f });
ASSERT_TRUE(edLdx.isSameShape(dLdx));
ASSERT_TRUE(edLdw.isSameShape(dLdw));
ASSERT_TRUE(edLdb.isSameShape(dLdb));
ASSERT_TRUE(edLdx.equalsTo(dLdx));
ASSERT_TRUE(edLdw.equalsTo(dLdw));
ASSERT_TRUE(edLdb.equalsTo(dLdb));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_5) {
auto x = NDArrayFactory::create<float>('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
auto w = NDArrayFactory::create<float>('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
auto b = NDArrayFactory::create<float>({ 100.f, 200.f });
auto dLdz = NDArrayFactory::create<float>('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f });
sd::ops::xw_plus_b_bp op;
auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto dLdx = result.at(0);
auto dLdw = result.at(1);
auto dLdb = result.at(2);
auto edLdxC = NDArrayFactory::create<float>('c', { 2,3 }, { 2705.f, 1818.f, 1026.f, 4912.f, 2967.f, 1850.f });
auto edLdwC = NDArrayFactory::create<float>('c', { 3,2 }, { 3297.f, 4094.f, 4438.f, 5613.f, 2422.f, 3271.f });
auto edLdbC = NDArrayFactory::create<float>('c', { 2 }, { 427.f, 584.f });
auto edLdx = NDArrayFactory::create<float>('f', { 2,3 });
auto edLdw = NDArrayFactory::create<float>('f', { 3,2 });
auto edLdb = NDArrayFactory::create<float>('f', { 2 });
edLdx.assign(edLdxC);
edLdw.assign(edLdwC);
edLdb.assign(edLdbC);
ASSERT_TRUE(edLdx.isSameShape(dLdx));
ASSERT_TRUE(edLdw.isSameShape(dLdw));
ASSERT_TRUE(edLdb.isSameShape(dLdb));
ASSERT_TRUE(edLdx.equalsTo(dLdx));
ASSERT_TRUE(edLdw.equalsTo(dLdw));
ASSERT_TRUE(edLdb.equalsTo(dLdb));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_6) {
auto x = NDArrayFactory::create<float>('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
auto w = NDArrayFactory::create<float>('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
auto b = NDArrayFactory::create<float>({ 100.f, 200.f });
auto dLdz = NDArrayFactory::create<float>('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f });
// mkl-format
w.permutei({ 1,0 });
sd::ops::xw_plus_b_bp op;
auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, { 1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto dLdx = result.at(0);
auto dLdw = result.at(1);
auto dLdb = result.at(2);
auto edLdx = NDArrayFactory::create<float>('c', { 2,3 }, { 2695.f, 2012.f, 1566.f, 4247.f, 2635.f, 2418.f });
auto edLdwC = NDArrayFactory::create<float>('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f });
auto edLdb = NDArrayFactory::create<float>('c', { 2 }, { 483.f, 543.f });
auto edLdw = NDArrayFactory::create<float>('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f });
edLdw.permutei({ 1,0 });
edLdw.assign(edLdwC);
ASSERT_TRUE(edLdx.isSameShape(dLdx));
ASSERT_TRUE(edLdw.isSameShape(dLdw));
ASSERT_TRUE(edLdb.isSameShape(dLdb));
ASSERT_TRUE(edLdx.equalsTo(dLdx));
ASSERT_TRUE(edLdw.equalsTo(dLdw));
ASSERT_TRUE(edLdb.equalsTo(dLdb));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, TestUpdaterSgd1) {


@@ -2432,18 +2432,36 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_3) {
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_1) {
auto x = NDArrayFactory::create<double>('c', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f});
auto y = NDArrayFactory::create<double>('c', {3,2}, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f});
auto b = NDArrayFactory::create<double>({100.f, 200.f});
auto x = NDArrayFactory::create<float>('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
auto y = NDArrayFactory::create<float>('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
auto b = NDArrayFactory::create<float>({ 100.f, 200.f });
auto exp = NDArrayFactory::create<double>('c', {2,2}, {173.f, 264.f, 310.f, 279.f});
auto exp = NDArrayFactory::create<float>('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f });
sd::ops::xw_plus_b op;
auto result = op.evaluate({&x, &y, &b}, {}, {});
auto result = op.evaluate({ &x, &y, &b });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_2) {
auto x = NDArrayFactory::create<float>('c', { 1, 2 }, { 1.f, 11.f });
auto y = NDArrayFactory::create<float>('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
auto b = NDArrayFactory::create<float>({ 100.f, 200.f, 300.f });
auto exp = NDArrayFactory::create<float>('c', { 1, 3 }, { 166.f, 269.f, 326.f });
sd::ops::xw_plus_b op;
auto result = op.evaluate({ &x, &y, &b }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
@@ -2452,9 +2470,107 @@ TEST_F(DeclarableOpsTests5, XWPlusB_1) {
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_3) {
auto x = NDArrayFactory::create<float>('c', { 1, 2 }, { 1.f, 11.f });
auto y = NDArrayFactory::create<float>('c', { 2, 1 }, { 11.f, 3.f });
auto b = NDArrayFactory::create<float>('c', { 1 }, { 200.f });
auto exp = NDArrayFactory::create<float>('c', { 1,1 }, { 244.f });
sd::ops::xw_plus_b op;
auto result = op.evaluate({ &x, &y, &b });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_4) {
auto x = NDArrayFactory::create<float>('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
auto y = NDArrayFactory::create<float>('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
auto b = NDArrayFactory::create<float>({ 100.f, 200.f });
auto exp = NDArrayFactory::create<float>('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f });
sd::ops::xw_plus_b op;
auto result = op.evaluate({ &x, &y, &b });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_5) {
auto x = NDArrayFactory::create<float>('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
auto y = NDArrayFactory::create<float>('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
y = y.transpose();
auto b = NDArrayFactory::create<float>({ 100.f, 200.f });
auto exp = NDArrayFactory::create<float>('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f });
sd::ops::xw_plus_b op;
auto result = op.evaluate({ &x, &y, &b }, {}, { 1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_6) {
auto x = NDArrayFactory::create<float>('c', { 3, 2 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
auto y = NDArrayFactory::create<float>('c', { 2, 1 }, { 11.f, 3.f });
auto b = NDArrayFactory::create<float>('c', { 1 }, { 100.f });
auto exp = NDArrayFactory::create<float>('c', { 3, 1 }, { 144.f, 175.f, 173.f });
sd::ops::xw_plus_b op;
auto result = op.evaluate({ &x, &y, &b });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_7) {
auto x = NDArrayFactory::create<float>('c', { 3, 4 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
auto y = NDArrayFactory::create<float>('c', { 4, 5 }, { 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 3.f, 11.f, 3.f, 11.f });
auto b = NDArrayFactory::create<float>('c', { 5 }, { 100.f, 200.f, 300.f, 400.f, 500.f });
auto exp = NDArrayFactory::create<float>('c', { 3, 5 }, { 219.f, 375.f, 531.f, 575.f, 731.f, 217.f, 317.f, 505.f, 517.f, 705.f, 248.f, 396.f, 496.f, 596.f, 696.f });
sd::ops::xw_plus_b op;
auto result = op.evaluate({ &x, &y, &b });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, StopGradient_1) {


@@ -77,7 +77,12 @@ TEST_F(MklDnnTests, helpers_includer) {
sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp;
sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b;
sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU xw_plus_b_bp;
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, &xw_plus_b_bp });
#endif
}