xw_plus_b mkldnn implementation (#247)
* libnd4j: first step of mkldnn for xw_plus_b and test of aurora crash in imageHelper
* libnd4j: sync folders with master
* libnd4j: merge master; raw implementation of xw_plus_b on mkldnn; clean-up; needs testing and checks for corresponding input shapes
* libnd4j: corrections and checks added to xw_plus_b mkl
* libnd4j: corrected dataType description based on mkl operation description, needs more investigation
* libnd4j: fixed xw_plus_b mkl implementation, needs testing
* libnd4j: two unit tests added
* libnd4j: fixed input-dimensions check bug
* libnd4j: one more test added to cover different order handling
* libnd4j: added optional int arg to define the weights format: if arg == 1, mkldnn format (no transpose needed in the mkldnn implementation), else mmul format; corrected check points; added unit test
* libnd4j: merge master
* libnd4j: some improvements to avoid NDArray transpose in the xw_plus_b operation
* libnd4j: fixed issues connected with weights rank; added support of one case based on tf (for mkldnn, cpu, cuda); test case added
* libnd4j: added proper handling of empty inputs (all implementations)
* libnd4j: fixed compilation error
* libnd4j: several more corrections after conflict resolution and fixed typos
* libnd4j: removed unsupported data types
* libnd4j: merge master and fixed issues
* libnd4j: added propagation (backprop) implementation for xw_plus_b; fixed issue connected with mkl weights data format; avoided data copy in transpose mode; test cases added; manually tested with gradCheck
* libnd4j: one minor fix of double operation declaration
* libnd4j: code clean-up
* libnd4j: minor test fixes
* libnd4j: fixed build problem, integrated helpers changes

Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: raver119 <raver119@gmail.com>
parent 29e61579c1
commit 1d004b542a
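For orientation before the diff: xw_plus_b computes a fully connected layer forward pass. With x of shape [M,K], weights W of shape [K,N] (or [N,K] when the mkldnn weights format is selected via the optional int arg), and bias b of length N:

    z = x * W + b        // z has shape [M,N]; b is added to every row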
@@ -14,10 +14,11 @@

 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// xw_plus_b op. Created by GS <george@skymind.io> 31.01.2018
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_xw_plus_b)
@@ -29,34 +30,113 @@

namespace sd {
namespace ops {

    CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) {

        auto x = INPUT_VARIABLE(0);
        auto b = INPUT_VARIABLE(2);
        auto z = OUTPUT_VARIABLE(0);

        if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty())
            return Status::OK();

        const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false);

        auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1);

        REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
        REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
        REQUIRE_TRUE(z->rankOf() == 2, 0, "xw_plus_b: Output array should have rank equal 2, but got instead %i!", z->rankOf());

        REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input bias vector should be 1D and have proper dimension 1x%i."
            " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1));

        // multiply x by w
        MmulHelper::mmul(x, w, z, 1.0, 0.0);

        // add the b vector to each row
        z->addiRowVector(*b);

        if (bTranspose)
            delete w;

        return Status::OK();
    }
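As a quick sanity check of the forward pass, the values from test XWPlusB_1 further down give:

    // x = [[1, 11, 3],        w = [[11, 3],       b = [100, 200]
    //      [14, 5, 6]]             [ 4, 5],
    //                              [ 6, 2]]
    // z[0][0] = 1*11  + 11*4 + 3*6 + 100 = 173
    // z[0][1] = 1*3   + 11*5 + 3*2 + 200 = 264
    // z[1][0] = 14*11 + 5*4  + 6*6 + 100 = 310
    // z[1][1] = 14*3  + 5*5  + 6*2 + 200 = 279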
    DECLARE_SHAPE_FN(xw_plus_b) {

        auto weights = INPUT_VARIABLE(1);

        const int nWeightsFormat = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;

        auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.getWorkspace()) : inputShape->at(1);

        auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), weightsShape, false, false,
            ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace());

        return SHAPELIST(CONSTANT(outputShape));
    }

    DECLARE_TYPES(xw_plus_b) {
        getOpDescriptor()
            ->setAllowedInputTypes(sd::DataType::ANY)
            ->setAllowedOutputTypes({ ALL_FLOATS });
    }
    CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) {

        auto x = INPUT_VARIABLE(0);
        auto b = INPUT_VARIABLE(2);
        auto dLdz = INPUT_VARIABLE(3);

        auto dLdx = OUTPUT_VARIABLE(0);
        auto dLdb = OUTPUT_VARIABLE(2);

        if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || dLdz->isEmpty())
            return Status::OK();

        const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false);

        auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1);

        REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
        REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
        REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf());
        REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, "xw_plus_b BP: Input bias vector should be 1D and have proper dimension 1x%i."
            " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1));

        auto dLdw = (bTranspose) ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) : OUTPUT_VARIABLE(1);

        // dLdb - sum dLdz over the batch dimension
        dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, { 0 }));

        // dLdx and dLdw - reuse matmul backprop
        matmul_bp mmul_bp;
        mmul_bp.execute({ x, w, dLdz }, std::vector<NDArray*>{dLdx, dLdw}, {}, {}, {});

        if (bTranspose) {
            delete w;
            delete dLdw;
        }

        return Status::OK();
    }

    DECLARE_SHAPE_FN(xw_plus_b_bp) {

        Nd4jLong* xShapeInfo;
        Nd4jLong* wShapeInfo;
        Nd4jLong* bShapeInfo;

        COPY_SHAPE(inputShape->at(0), xShapeInfo);
        COPY_SHAPE(inputShape->at(1), wShapeInfo);
        COPY_SHAPE(inputShape->at(2), bShapeInfo);

        return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(wShapeInfo), CONSTANT(bShapeInfo));
    }

    DECLARE_TYPES(xw_plus_b_bp) {
        getOpDescriptor()
            ->setAllowedInputTypes(sd::DataType::ANY)
            ->setAllowedOutputTypes({ ALL_FLOATS });
    }
}
}
@@ -867,9 +867,12 @@ namespace sd {

     * - 2D matrix MxN
     * - 1D vector with N elements
     * output value - 2D matrix NxN, the product of the matrices with the vector added
     * Int args:
     *      0 - optional weights format switch: if the int arg == 1, weights are in mkldnn format (no transpose needed), otherwise in mmul format
     */
    #if NOT_EXCLUDED(OP_xw_plus_b)
    DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0);
    DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0);
    #endif
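For illustration, a minimal usage sketch of the two weight formats, mirroring tests XWPlusB_1 and XWPlusB_5 below (the arrays x, w, wT, b are assumed to be set up as in those tests):

    sd::ops::xw_plus_b op;

    // default (mmul) weights format: w stored as [K,N]
    auto r1 = op.evaluate({ &x, &w, &b });

    // mkldnn weights format (iArg 0 == 1): wT stored as [N,K], no transpose copy inside
    auto r2 = op.evaluate({ &x, &wT, &b }, {}, { 1 });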

/**
@@ -96,6 +96,10 @@ namespace sd {

        DECLARE_PLATFORM(tanh_bp, ENGINE_CPU);

        DECLARE_PLATFORM(xw_plus_b, ENGINE_CPU);

        DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU);
    }
}
@@ -0,0 +1,426 @@

/*******************************************************************************
 * Copyright (c) 2019-2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"

using namespace dnnl;

namespace sd {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////
static void xwPlusBiasMKLDNN(const NDArray* x, const NDArray* weights, const NDArray* bias, NDArray* z, const bool bShouldTransp) {

    // mkl inner_product works with the following layout:
    // [M,K] x [N,K]^T + [N] = [M,N]
    const auto xRank = x->rankOf();

    // [M,K] x [K,N] = [M,N]
    const int M = x->sizeAt(0);
    const int K = x->sizeAt(1);     // K == wK
    const int N = z->sizeAt(1);

    dnnl::memory::dims xShape = dnnl::memory::dims({ M, K });
    dnnl::memory::dims wShape = dnnl::memory::dims({ N, K });
    dnnl::memory::dims zShape = dnnl::memory::dims({ M, N });
    dnnl::memory::dims bShape = dnnl::memory::dims({ N });

    dnnl::memory::format_tag format = dnnl::memory::format_tag::ab;

    // x type
    dnnl::memory::data_type xType = dnnl::memory::data_type::f32;
    if (x->dataType() == DataType::UINT8)
        xType = dnnl::memory::data_type::u8;
    else if (x->dataType() == DataType::INT8)
        xType = dnnl::memory::data_type::s8;

    // weights type
    dnnl::memory::data_type wType = (weights->dataType() == DataType::FLOAT32) ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::s8;

    // bias type
    dnnl::memory::data_type bType = dnnl::memory::data_type::f32;
    if (bias->dataType() == DataType::INT32)
        bType = dnnl::memory::data_type::s32;
    else if (bias->dataType() == DataType::UINT8)
        bType = dnnl::memory::data_type::u8;
    else if (bias->dataType() == DataType::INT8)
        bType = dnnl::memory::data_type::s8;

    // z type
    dnnl::memory::data_type zType = dnnl::memory::data_type::f32;
    if (z->dataType() == DataType::INT32)
        zType = dnnl::memory::data_type::s32;
    else if (z->dataType() == DataType::UINT8)
        zType = dnnl::memory::data_type::u8;
    else if (z->dataType() == DataType::INT8)
        zType = dnnl::memory::data_type::s8;

    // memory descriptors for arrays

    // x
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
    mkldnnUtils::setBlockStrides(x, x_user_md);

    // weights
    dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any);
    dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, format);
    if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
        weights_user_md.data.format_kind = dnnl_blocked;    // overrides format
        if (bShouldTransp) {
            // transpose [K,N] -> [N,K] by swapping the strides, no data copy
            weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
            weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
        }
        else {
            weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
            weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
        }
    }

    // bias
    dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
    dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
    mkldnnUtils::setBlockStrides(bias, bias_user_md);

    // z
    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
    mkldnnUtils::setBlockStrides(z, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // operation primitive description
    dnnl::inner_product_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, z_mkl_md);
    dnnl::inner_product_forward::primitive_desc op_prim_desc(op_desc, engine);

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> args;

    dnnl::stream stream(engine);

    // provide memory buffers and check whether reorder is required

    // input
    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // weights
    mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

    // bias
    auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, bias->getBuffer());
    args[DNNL_ARG_BIAS] = bias_mkl_mem;

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
    args[DNNL_ARG_DST] = z_mkl_mem;

    // run calculations
    dnnl::inner_product_forward(op_prim_desc).execute(stream, args);

    // reorder output if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);

    stream.wait();
}
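A note on the design choice above (editor's summary): the transpose mentioned in the commit message is never materialized. When bShouldTransp is set, the user memory descriptor simply swaps the two strides, and oneDNN's reorder takes care of any remaining layout conversion:

    // For a c-order [K,N] weights buffer with strides {N, 1}, describing it as
    // [N,K] with strides {1, N} reads the same buffer as its transpose, no copy:
    //   wT(i, j) == buffer[j * N + i] == w(j, i)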
//////////////////////////////////////////////////////////////////////
static void xwPlusBiasBp(const NDArray* x, const NDArray* weights, const NDArray* bias, const NDArray* dLdz,
                         NDArray* dLdx, NDArray* dLdw, NDArray* dLdb, const bool bShouldTransp) {

    // mkl inner_product works with the following layout:
    // [M,K] x [N,K]^T + [N] = [M,N]
    const auto xRank = x->rankOf();

    // [M,K] x [K,N] = [M,N]
    const int M = x->sizeAt(0);
    const int K = x->sizeAt(1);     // K == wK
    const int N = dLdz->sizeAt(1);

    // input dims
    dnnl::memory::dims xShape = dnnl::memory::dims({ M, K });
    dnnl::memory::dims wShape = dnnl::memory::dims({ N, K });
    dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N });
    dnnl::memory::dims bShape = dnnl::memory::dims({ N });

    // output dims
    dnnl::memory::dims dLdxShape = xShape;
    dnnl::memory::dims dLdwShape = wShape;

    dnnl::memory::format_tag format = dnnl::memory::format_tag::ab;
    dnnl::memory::data_type dataType = dnnl::memory::data_type::f32;

    // memory descriptors for arrays

    // x
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format);
    mkldnnUtils::setBlockStrides(x, x_user_md);

    // weights
    dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any);
    dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, format);
    if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
        weights_user_md.data.format_kind = dnnl_blocked;    // overrides format
        if (bShouldTransp) {
            weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
            weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
        }
        else {
            weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
            weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
        }
    }

    // bias
    dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
    dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
    mkldnnUtils::setBlockStrides(bias, bias_user_md);

    // dLdz
    dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, format);
    mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);

    // dLdw
    dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format);
    dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, format);
    if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) {
        dLdw_user_md.data.format_kind = dnnl_blocked;       // overrides format
        if (bShouldTransp) {
            dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1);
            dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0);
        }
        else {
            dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0);
            dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1);
        }
    }

    // dLdb
    dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
    dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
    mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md);

    // dLdx
    dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, format);
    mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // forward operation primitive description (hint for the backward primitives)
    dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md);
    dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);

    // backprop: dLdw and dLdb
    auto op_bpdw_desc = inner_product_backward_weights::desc(x_mkl_md, dLdw_mkl_md, dLdb_mkl_md, dLdz_mkl_md);
    auto op_bpdw_prim_desc = inner_product_backward_weights::primitive_desc(op_bpdw_desc, engine, op_ff_prim_desc);

    // backprop: dLdx
    auto op_bpdx_desc = inner_product_backward_data::desc(dLdx_mkl_md, weights_mkl_md, dLdz_mkl_md);
    auto op_bpdx_prim_desc = inner_product_backward_data::primitive_desc(op_bpdx_desc, engine, op_ff_prim_desc);

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> argsDw, argsDx;

    dnnl::stream stream(engine);

    // dLdz - dw
    mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]);

    // dLdz - dx
    mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]);

    // input x for dw
    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]);

    // weights - dx
    mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]);

    // dLdw
    auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->getBuffer());
    const bool dLdwReorder = op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc();
    auto dLdw_mkl_mem = dLdwReorder ? dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) : dLdw_user_mem;
    argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem;

    // dLdx
    auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());
    const bool dLdxReorder = op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc();
    auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) : dLdx_user_mem;
    argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;

    // dLdb
    auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->getBuffer());
    const bool dLdbReorder = op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc();
    auto dLdb_mkl_mem = dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) : dLdb_user_mem;
    argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem;

    // run calculations dw
    dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw);
    // run calculations dx
    dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx);

    // reorder outputs if necessary
    if (dLdxReorder)
        dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);

    if (dLdwReorder)
        dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem).execute(stream, dLdw_mkl_mem, dLdw_user_mem);

    if (dLdbReorder)
        dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem).execute(stream, dLdb_mkl_mem, dLdb_user_mem);

    stream.wait();
}
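For reference, the two backward primitives above produce the standard dense-layer gradients (editor's summary, stated in mkl's [N,K] weights layout):

    // dLdx = dLdz x W     : [M,N] x [N,K] -> [M,K]   (inner_product_backward_data)
    // dLdw = dLdz^T x x   : [N,M] x [M,K] -> [N,K]   (inner_product_backward_weights)
    // dLdb = column sums of dLdz           -> [N]    (inner_product_backward_weights)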
PLATFORM_IMPL(xw_plus_b, ENGINE_CPU) {

    auto x = INPUT_VARIABLE(0);
    auto w = INPUT_VARIABLE(1);
    auto b = INPUT_VARIABLE(2);
    auto z = OUTPUT_VARIABLE(0);

    if (x->isEmpty() || w->isEmpty() || b->isEmpty())
        return Status::OK();

    const int xRank = x->rankOf();
    const int wRank = w->rankOf();
    const int zRank = z->rankOf();

    const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true;    // [M,K] * [K,N] -> [M,N]; mkl -> [M,K] * [N,K]^T -> [M,N]

    REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank);
    REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank);
    REQUIRE_TRUE(zRank == 2, 0, "xw_plus_b MKL: Output array should have rank equal 2, but got instead %i!", zRank);

    REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b MKL: Input bias vector should be 1D and have proper dimension 1x%i."
        " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1));

    // mkldnn inner product
    xwPlusBiasMKLDNN(x, w, b, z, bShouldTransp);

    return Status::OK();
}

PLATFORM_CHECK(xw_plus_b, ENGINE_CPU) {

    auto x = INPUT_VARIABLE(0);
    auto w = INPUT_VARIABLE(1);
    auto b = INPUT_VARIABLE(2);
    auto z = OUTPUT_VARIABLE(0);

    const DataType xType = x->dataType();
    const DataType wType = w->dataType();
    const DataType bType = b->dataType();
    const DataType zType = z->dataType();

    /*
    Source    Weights    Destination         Bias
    f32       f32        f32                 f32
    u8, s8    s8         u8, s8, s32, f32    u8, s8, s32, f32
    */
    return block.isUseMKLDNN() &&
        ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && bType == DataType::FLOAT32 && zType == DataType::FLOAT32) ||
         (  // x
            (xType == DataType::UINT8 || xType == DataType::INT8) &&
            // w
            (wType == DataType::UINT8 || wType == DataType::INT8) &&
            // b
            (bType == DataType::UINT8 || bType == DataType::INT8 || bType == DataType::INT32 || bType == DataType::FLOAT32) &&
            // z
            (zType == DataType::UINT8 || zType == DataType::INT8 || zType == DataType::INT32 || zType == DataType::FLOAT32)
         ));
}

PLATFORM_IMPL(xw_plus_b_bp, ENGINE_CPU) {

    auto x = INPUT_VARIABLE(0);
    auto w = INPUT_VARIABLE(1);
    auto b = INPUT_VARIABLE(2);
    auto dLdz = INPUT_VARIABLE(3);

    auto dLdx = OUTPUT_VARIABLE(0);
    auto dLdw = OUTPUT_VARIABLE(1);
    auto dLdb = OUTPUT_VARIABLE(2);

    if (x->isEmpty() || w->isEmpty() || b->isEmpty() || dLdz->isEmpty())
        return Status::OK();

    const int xRank = x->rankOf();
    const int wRank = w->rankOf();
    const int dLdzRank = dLdz->rankOf();

    const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true;    // [M,K] * [K,N] -> [M,N]; mkl -> [M,K] * [N,K]^T -> [M,N]

    REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
    REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
    REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP MKL: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf());
    REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(1), 0, "xw_plus_b BP MKL: Input bias vector should be 1D and have proper dimension 1x%i."
        " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(1));

    xwPlusBiasBp(x, w, b, dLdz, dLdx, dLdw, dLdb, bShouldTransp);

    return Status::OK();
}

PLATFORM_CHECK(xw_plus_b_bp, ENGINE_CPU) {

    auto x = INPUT_VARIABLE(0);
    auto w = INPUT_VARIABLE(1);
    auto b = INPUT_VARIABLE(2);
    auto dLdz = INPUT_VARIABLE(3);

    auto dLdx = OUTPUT_VARIABLE(0);
    auto dLdw = OUTPUT_VARIABLE(1);
    auto dLdb = OUTPUT_VARIABLE(2);

    const DataType xType = x->dataType();
    const DataType wType = w->dataType();
    const DataType bType = b->dataType();
    const DataType dLdzType = dLdz->dataType();
    const DataType dLdxType = dLdx->dataType();
    const DataType dLdwType = dLdw->dataType();
    const DataType dLdbType = dLdb->dataType();

    /*
    Source    Weights    Destination    Bias
    f32       f32        f32            f32
    */
    return block.isUseMKLDNN() &&
        (xType == DataType::FLOAT32 && wType == DataType::FLOAT32 &&
         bType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 &&
         dLdbType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32 &&
         dLdwType == DataType::FLOAT32);
}

}
}
}
@@ -15,9 +15,9 @@

 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
@@ -45,19 +45,19 @@ TEST_F(DeclarableOpsTests18, test_bitcast_1) {

    auto e = NDArrayFactory::create<Nd4jLong>(4597464930322771456L);

    sd::ops::bitcast op;
    auto status = op.execute({ &x }, { &z }, {}, { (Nd4jLong)sd::DataType::INT64 }, {});
    ASSERT_EQ(Status::OK(), status);

    ASSERT_EQ(e, z);
}

TEST_F(DeclarableOpsTests18, test_tanh_1) {
    auto x = NDArrayFactory::create<float>('c', { 8 }, { 0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f });
    auto z = x.ulike();
    auto e = NDArrayFactory::create<float>('c', { 8 }, { 0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, -1.f });

    sd::ops::tanh op;
    op.execute({ &x }, { &z });

    ASSERT_EQ(e, z);
}
@@ -187,6 +187,197 @@ TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) {

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(output.equalsTo(exp));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_1) {

    auto x = NDArrayFactory::create<float>('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
    auto w = NDArrayFactory::create<float>('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
    auto b = NDArrayFactory::create<float>({ 100.f, 200.f });

    NDArray dLdz('c', { 2, 2 }, DataType::FLOAT32);
    dLdz.linspace(1);

    sd::ops::xw_plus_b_bp op;
    auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto dLdx = result.at(0);
    auto dLdw = result.at(1);
    auto dLdb = result.at(2);

    auto edLdx = NDArrayFactory::create<float>('c', { 2,3 }, { 17.f, 14.f, 10.f, 45.f, 32.f, 26.f });
    auto edLdw = NDArrayFactory::create<float>('c', { 3,2 }, { 43.f, 58.f, 26.f, 42.f, 21.f, 30.f });
    auto edLdb = NDArrayFactory::create<float>('c', { 2 }, { 4.f, 6.f });

    ASSERT_TRUE(edLdx.isSameShape(dLdx));
    ASSERT_TRUE(edLdw.isSameShape(dLdw));
    ASSERT_TRUE(edLdb.isSameShape(dLdb));
    ASSERT_TRUE(edLdx.equalsTo(dLdx));
    ASSERT_TRUE(edLdw.equalsTo(dLdw));
    ASSERT_TRUE(edLdb.equalsTo(dLdb));
}
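The expected gradients in this test can be checked by hand (dLdz = [[1,2],[3,4]] from linspace(1)):

    // dLdb = column sums of dLdz = [1+3, 2+4]                = [4, 6]
    // dLdx = dLdz x w^T: [[1*11+2*3, 1*4+2*5, 1*6+2*2], ...] = [[17,14,10],[45,32,26]]
    // dLdw = x^T x dLdz: [[1*1+14*3, 1*2+14*4], ...]         = [[43,58],[26,42],[21,30]]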
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_2) {

    auto x = NDArrayFactory::create<float>('c', { 6,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
    auto w = NDArrayFactory::create<float>('c', { 3,4 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f, 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
    auto b = NDArrayFactory::create<float>('c', { 4 }, { 100.f, 200.f, 100.f, 200.f });

    NDArray dLdz('c', { 6, 4 }, DataType::FLOAT32);
    dLdz.linspace(.1, .5);

    sd::ops::xw_plus_b_bp op;
    auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto dLdx = result.at(0);
    auto dLdw = result.at(1);
    auto dLdb = result.at(2);

    auto edLdx = NDArrayFactory::create<float>('c', { 6,3 }, { 15.3f, 18.700001f, 13.2f, 61.299995f, 62.699997f, 47.200001f, 107.299995f, 106.699997f, 81.199997f, 153.299988f, 150.699997f, 115.199997f, 199.300018f, 194.700012f, 149.199997f, 245.300018f, 238.700012f, 183.199997f });
    auto edLdw = NDArrayFactory::create<float>('c', { 3,4 }, { 268.5f, 291.f, 313.5f, 336.f, 226.800003f, 250.800003f, 274.799988f, 298.799988f, 146.699997f, 160.199997f, 173.700012f, 187.200012f });
    auto edLdb = NDArrayFactory::create<float>('c', { 4 }, { 30.6f, 33.599998f, 36.599998f, 39.599998f });

    ASSERT_TRUE(edLdx.isSameShape(dLdx));
    ASSERT_TRUE(edLdw.isSameShape(dLdw));
    ASSERT_TRUE(edLdb.isSameShape(dLdb));
    ASSERT_TRUE(edLdx.equalsTo(dLdx));
    ASSERT_TRUE(edLdw.equalsTo(dLdw));
    ASSERT_TRUE(edLdb.equalsTo(dLdb));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_3) {

    auto x = NDArrayFactory::create<float>('c', { 1, 2 }, { 1.f, 11.f });
    auto w = NDArrayFactory::create<float>('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
    auto b = NDArrayFactory::create<float>({ 100.f, 200.f, 300.f });

    auto dLdz = NDArrayFactory::create<float>('c', { 1, 3 }, { 166.f, 269.f, 326.f });

    sd::ops::xw_plus_b_bp op;
    auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto dLdx = result.at(0);
    auto dLdw = result.at(1);
    auto dLdb = result.at(2);

    auto edLdx = NDArrayFactory::create<float>('c', { 1,2 }, { 3937.f, 3096.f });
    auto edLdw = NDArrayFactory::create<float>('c', { 2,3 }, { 166.f, 269.f, 326.f, 1826.f, 2959.f, 3586.f });
    auto edLdb = NDArrayFactory::create<float>('c', { 3 }, { 166.f, 269.f, 326.f });

    ASSERT_TRUE(edLdx.isSameShape(dLdx));
    ASSERT_TRUE(edLdw.isSameShape(dLdw));
    ASSERT_TRUE(edLdb.isSameShape(dLdb));
    ASSERT_TRUE(edLdx.equalsTo(dLdx));
    ASSERT_TRUE(edLdw.equalsTo(dLdw));
    ASSERT_TRUE(edLdb.equalsTo(dLdb));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_4) {

    auto x = NDArrayFactory::create<float>('c', { 1, 2 }, { 1.f, 11.f });
    auto w = NDArrayFactory::create<float>('c', { 2, 1 }, { 11.f, 3.f });
    auto b = NDArrayFactory::create<float>('c', { 1 }, { 200.f });

    auto dLdz = NDArrayFactory::create<float>('c', { 1,1 }, { 244.f });

    sd::ops::xw_plus_b_bp op;
    auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto dLdx = result.at(0);
    auto dLdw = result.at(1);
    auto dLdb = result.at(2);

    auto edLdx = NDArrayFactory::create<float>('c', { 1,2 }, { 2684.f, 732.f });
    auto edLdw = NDArrayFactory::create<float>('c', { 2,1 }, { 244.f, 2684.f });
    auto edLdb = NDArrayFactory::create<float>('c', { 1 }, { 244.f });

    ASSERT_TRUE(edLdx.isSameShape(dLdx));
    ASSERT_TRUE(edLdw.isSameShape(dLdw));
    ASSERT_TRUE(edLdb.isSameShape(dLdb));
    ASSERT_TRUE(edLdx.equalsTo(dLdx));
    ASSERT_TRUE(edLdw.equalsTo(dLdw));
    ASSERT_TRUE(edLdb.equalsTo(dLdb));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_5) {

    auto x = NDArrayFactory::create<float>('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
    auto w = NDArrayFactory::create<float>('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
    auto b = NDArrayFactory::create<float>({ 100.f, 200.f });

    auto dLdz = NDArrayFactory::create<float>('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f });

    sd::ops::xw_plus_b_bp op;
    auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto dLdx = result.at(0);
    auto dLdw = result.at(1);
    auto dLdb = result.at(2);

    auto edLdxC = NDArrayFactory::create<float>('c', { 2,3 }, { 2705.f, 1818.f, 1026.f, 4912.f, 2967.f, 1850.f });
    auto edLdwC = NDArrayFactory::create<float>('c', { 3,2 }, { 3297.f, 4094.f, 4438.f, 5613.f, 2422.f, 3271.f });
    auto edLdbC = NDArrayFactory::create<float>('c', { 2 }, { 427.f, 584.f });

    auto edLdx = NDArrayFactory::create<float>('f', { 2,3 });
    auto edLdw = NDArrayFactory::create<float>('f', { 3,2 });
    auto edLdb = NDArrayFactory::create<float>('f', { 2 });

    edLdx.assign(edLdxC);
    edLdw.assign(edLdwC);
    edLdb.assign(edLdbC);

    ASSERT_TRUE(edLdx.isSameShape(dLdx));
    ASSERT_TRUE(edLdw.isSameShape(dLdw));
    ASSERT_TRUE(edLdb.isSameShape(dLdb));
    ASSERT_TRUE(edLdx.equalsTo(dLdx));
    ASSERT_TRUE(edLdw.equalsTo(dLdw));
    ASSERT_TRUE(edLdb.equalsTo(dLdb));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, XWPlusB_Bp_6) {

    auto x = NDArrayFactory::create<float>('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
    auto w = NDArrayFactory::create<float>('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
    auto b = NDArrayFactory::create<float>({ 100.f, 200.f });

    auto dLdz = NDArrayFactory::create<float>('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f });

    // mkl-format
    w.permutei({ 1,0 });

    sd::ops::xw_plus_b_bp op;
    auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, { 1 });

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto dLdx = result.at(0);
    auto dLdw = result.at(1);
    auto dLdb = result.at(2);

    auto edLdx = NDArrayFactory::create<float>('c', { 2,3 }, { 2695.f, 2012.f, 1566.f, 4247.f, 2635.f, 2418.f });
    auto edLdwC = NDArrayFactory::create<float>('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f });
    auto edLdb = NDArrayFactory::create<float>('c', { 2 }, { 483.f, 543.f });
    auto edLdw = NDArrayFactory::create<float>('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f });
    edLdw.permutei({ 1,0 });
    edLdw.assign(edLdwC);

    ASSERT_TRUE(edLdx.isSameShape(dLdx));
    ASSERT_TRUE(edLdw.isSameShape(dLdw));
    ASSERT_TRUE(edLdb.isSameShape(dLdb));
    ASSERT_TRUE(edLdx.equalsTo(dLdx));
    ASSERT_TRUE(edLdw.equalsTo(dLdw));
    ASSERT_TRUE(edLdb.equalsTo(dLdb));
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, TestUpdaterSgd1) {
@@ -2432,18 +2432,36 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_3) {

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_1) {

    auto x = NDArrayFactory::create<float>('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
    auto y = NDArrayFactory::create<float>('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
    auto b = NDArrayFactory::create<float>({ 100.f, 200.f });

    auto exp = NDArrayFactory::create<float>('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f });

    sd::ops::xw_plus_b op;
    auto result = op.evaluate({ &x, &y, &b });

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto output = result.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_2) {

    auto x = NDArrayFactory::create<float>('c', { 1, 2 }, { 1.f, 11.f });
    auto y = NDArrayFactory::create<float>('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
    auto b = NDArrayFactory::create<float>({ 100.f, 200.f, 300.f });

    auto exp = NDArrayFactory::create<float>('c', { 1, 3 }, { 166.f, 269.f, 326.f });

    sd::ops::xw_plus_b op;
    auto result = op.evaluate({ &x, &y, &b }, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());
@@ -2452,9 +2470,107 @@ TEST_F(DeclarableOpsTests5, XWPlusB_1) {

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_3) {

    auto x = NDArrayFactory::create<float>('c', { 1, 2 }, { 1.f, 11.f });
    auto y = NDArrayFactory::create<float>('c', { 2, 1 }, { 11.f, 3.f });
    auto b = NDArrayFactory::create<float>('c', { 1 }, { 200.f });

    auto exp = NDArrayFactory::create<float>('c', { 1,1 }, { 244.f });

    sd::ops::xw_plus_b op;
    auto result = op.evaluate({ &x, &y, &b });

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto output = result.at(0);
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_4) {

    auto x = NDArrayFactory::create<float>('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
    auto y = NDArrayFactory::create<float>('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });
    auto b = NDArrayFactory::create<float>({ 100.f, 200.f });

    auto exp = NDArrayFactory::create<float>('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f });

    sd::ops::xw_plus_b op;
    auto result = op.evaluate({ &x, &y, &b });

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto output = result.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_5) {

    auto x = NDArrayFactory::create<float>('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
    auto y = NDArrayFactory::create<float>('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f });

    y = y.transpose();

    auto b = NDArrayFactory::create<float>({ 100.f, 200.f });

    auto exp = NDArrayFactory::create<float>('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f });

    sd::ops::xw_plus_b op;
    auto result = op.evaluate({ &x, &y, &b }, {}, { 1 });

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto output = result.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_6) {

    auto x = NDArrayFactory::create<float>('c', { 3, 2 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
    auto y = NDArrayFactory::create<float>('c', { 2, 1 }, { 11.f, 3.f });

    auto b = NDArrayFactory::create<float>('c', { 1 }, { 100.f });

    auto exp = NDArrayFactory::create<float>('c', { 3, 1 }, { 144.f, 175.f, 173.f });

    sd::ops::xw_plus_b op;
    auto result = op.evaluate({ &x, &y, &b });

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto output = result.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, XWPlusB_7) {

    auto x = NDArrayFactory::create<float>('c', { 3, 4 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f });
    auto y = NDArrayFactory::create<float>('c', { 4, 5 }, { 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 3.f, 11.f, 3.f, 11.f });

    auto b = NDArrayFactory::create<float>('c', { 5 }, { 100.f, 200.f, 300.f, 400.f, 500.f });

    auto exp = NDArrayFactory::create<float>('c', { 3, 5 }, { 219.f, 375.f, 531.f, 575.f, 731.f, 217.f, 317.f, 505.f, 517.f, 705.f, 248.f, 396.f, 496.f, 596.f, 696.f });

    sd::ops::xw_plus_b op;
    auto result = op.evaluate({ &x, &y, &b });

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto output = result.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, StopGradient_1) {
@@ -77,7 +77,12 @@ TEST_F(MklDnnTests, helpers_includer) {

    sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp;

    sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b;
    sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU xw_plus_b_bp;

    printer({ &conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, &xw_plus_b_bp });

#endif
}