Tanh backpropagation mkldnn implementation (#308)
* libnd4j first step of tanh_bp operation implementation on mkldnn Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j optimize several places and added test case for tanh_bp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j minor corrections and renaming, added one more test case Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j missed mkldnn data format definition Signed-off-by: Oleg <oleg.semeniv@gmail.com>master
parent
e42b4e96c3
commit
e7a995e959
|
@ -94,6 +94,8 @@ namespace sd {
|
|||
|
||||
DECLARE_PLATFORM(tanh, ENGINE_CPU);
|
||||
|
||||
DECLARE_PLATFORM(tanh_bp, ENGINE_CPU);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -109,12 +109,126 @@ namespace sd {
|
|||
const DataType zType = z->dataType();
|
||||
|
||||
const int xRank = x->rankOf();
|
||||
bool bSupportedRanks = xRank < 7;
|
||||
bool bSupportedRanks = !x->isEmpty() && xRank < 7 && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32);
|
||||
/*
|
||||
Source Destination
|
||||
f32 f32
|
||||
*/
|
||||
return !x->isEmpty() && block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32);
|
||||
return block.isUseMKLDNN() && bSupportedRanks;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) {
|
||||
|
||||
const auto xRank = x->rankOf();
|
||||
dnnl::memory::dims xShape, dLdzShape, dLdxShape;
|
||||
|
||||
mkldnnUtils::getDims(x, xRank, xShape);
|
||||
mkldnnUtils::getDims(dLdz, xRank, dLdzShape);
|
||||
mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
|
||||
|
||||
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
|
||||
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
|
||||
|
||||
// dLdz
|
||||
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
mkldnnUtils::setBlockStrides(dLdz, xRank, dLdz_user_md);
|
||||
|
||||
// dLdx
|
||||
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// operation primitive description
|
||||
// forward
|
||||
dnnl::eltwise_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, algorithm::eltwise_tanh, x_mkl_md, 0, 0);
|
||||
dnnl::eltwise_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||
|
||||
// backward description
|
||||
dnnl::eltwise_backward::desc op_desc(algorithm::eltwise_tanh, dLdz_mkl_md, x_mkl_md, 0, 0);
|
||||
dnnl::eltwise_backward::primitive_desc op_prim_desc(op_desc, engine, op_ff_prim_desc);
|
||||
|
||||
// provide memory buffers and check whether reorder is required for forward
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||
|
||||
// dLdz
|
||||
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, args, dLdz_user_md, op_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
|
||||
|
||||
// dLdx
|
||||
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());
|
||||
const bool dLdxReorder = op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc();
|
||||
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_prim_desc.diff_src_desc(), engine) : dLdx_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
|
||||
|
||||
// run calculations backward
|
||||
dnnl::eltwise_backward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (dLdxReorder)
|
||||
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
|
||||
|
||||
PLATFORM_IMPL(tanh_bp, ENGINE_CPU) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto dLdz = INPUT_VARIABLE(1);
|
||||
auto dLdx = OUTPUT_VARIABLE(0);
|
||||
|
||||
const int rank = input->rankOf();
|
||||
const int dLdzRank = dLdz->rankOf();
|
||||
|
||||
REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "TANH_BP_MKLDNN OP: the rank of input and dLdz must be less or qual 6, but got input rank = %i and dLdz rank rank = %i instead !", rank, dLdzRank);
|
||||
|
||||
// mkldnnSoftMax
|
||||
tanhBpMKLDNN(input, dLdz, dLdx);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(tanh_bp, ENGINE_CPU) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto dLdz = INPUT_VARIABLE(1);
|
||||
auto dLdx = OUTPUT_VARIABLE(0);
|
||||
|
||||
const DataType xType = x->dataType();
|
||||
const DataType dLdzType = dLdz->dataType();
|
||||
const DataType dLdxType = dLdx->dataType();
|
||||
|
||||
const int xRank = x->rankOf();
|
||||
const int dLdzRank = dLdz->rankOf();
|
||||
|
||||
bool bSupportedRanks = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty());
|
||||
bSupportedRanks &= (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32);
|
||||
|
||||
if (bSupportedRanks) {
|
||||
for (int i = 0; i < xRank; i++) {
|
||||
if (x->sizeAt(i) != dLdz->sizeAt(i)) {
|
||||
bSupportedRanks = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Source Destination
|
||||
//f32 f32
|
||||
return block.isUseMKLDNN() && bSupportedRanks;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -75,6 +75,9 @@ TEST_F(MklDnnTests, helpers_includer) {
|
|||
|
||||
sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh;
|
||||
|
||||
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh });
|
||||
sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp;
|
||||
|
||||
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp });
|
||||
|
||||
#endif
|
||||
}
|
Loading…
Reference in New Issue