Shyrma mkl test (#211)
* - provide nhwc format in mkl conv ops
* - corrections in mkl conv3d
* - corrections in mkl batchnorm
* - corrections in mkl maxpooling2d
* - add format_tag::any to outputs in mkl conv ops
* - complete corrections in mkl conv ops
* - add test comparing execution speed of the mkl conv2d op with different weights formats
* - take into account order 'f' in mkl conv ops

Signed-off-by: Yurii <iuriish@yahoo.com>

parent 5ae40f6e38
commit 948646b32d
@@ -169,8 +169,8 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
    // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW

    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes

@@ -178,8 +178,8 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {

    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if(!isNCDHW) {
        input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
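Note: the permute({0, 4, 1, 2, 3}) above converts NDHWC data into the NCDHW layout the helper expects; libnd4j's permute appears to produce a view over the same buffer, which is why it is wrapped in a new NDArray and deleted afterwards. A minimal sketch of what that index permutation does to a shape (plain standalone C++, not nd4j code):

    #include <array>
    #include <cstdio>

    int main() {
        // NDHWC shape [bS, iD, iH, iW, iC] and the permutation used above
        std::array<long, 5> ndhwc = {8, 4, 16, 16, 3};
        std::array<int, 5>  perm  = {0, 4, 1, 2, 3};

        std::array<long, 5> ncdhw;
        for (int i = 0; i < 5; ++i)
            ncdhw[i] = ndhwc[perm[i]]; // axis i of the result is axis perm[i] of the source

        for (long d : ncdhw) printf("%ld ", d); // prints 8 3 4 16 16, i.e. [bS, iC, iD, iH, iW]
        printf("\n");
    }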
@@ -250,7 +250,7 @@ void pooling3dCUDNN(const LaunchContext* context,

    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
    cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err);
    printf("fffffffffff\n");

    const int numDims = 5;

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
@@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/PlatformHelper.h>
@@ -36,103 +37,44 @@ namespace platforms {

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);

    REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
                 input->rankOf());
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
    auto argI = *(block.getIArguments());
    auto output = OUTPUT_VARIABLE(0);

    const auto kH = INT_ARG(0);
    const auto kW = INT_ARG(1);
    const auto sH = INT_ARG(2);
    const auto sW = INT_ARG(3);
    int pH = INT_ARG(4);
    int pW = INT_ARG(5);
    auto pH = INT_ARG(4);
    auto pW = INT_ARG(5);
    const auto dH = INT_ARG(6);
    const auto dW = INT_ARG(7);
    const auto isSameMode = static_cast<bool>(INT_ARG(8));
    const auto paddingMode = INT_ARG(8);
    const auto extraParam0 = INT_ARG(9);
    const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC

    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
                 dH, dW);
    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    int oH = 0;
    int oW = 0;
    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC

    const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
    const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));

    if (!isNCHW) {
        input = new NDArray(
                input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        output = new NDArray(
                output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

    if (isSameMode)
    if (paddingMode)
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    const int bS = input->sizeAt(0);
    const int iC = input->sizeAt(1);
    const int oC = output->sizeAt(1);
    auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;

    auto poolingMode = PoolingType::AVG_POOL;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;
    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
                                           true,
                                           bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
                                           algorithm,
                                           &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
                                           &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
                                           pool_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
    auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
    auto pool_src_memory = user_src_memory;
    dnnl::stream stream(engine);
    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
    }
    auto pool_dst_memory = user_dst_memory;
    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
    }
    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
                                                     {DNNL_ARG_DST, pool_dst_memory}});
    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
    }
    stream.wait();

    //streams[0].submitAndWait();

    if (!isNCHW) {
        delete input;
        delete output;
    }
    mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode);

    return Status::OK();
}
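Note: the boilerplate being deleted above follows oneDNN's standard "user memory vs. primitive memory" idiom: bind the user's buffer, and reorder into a scratch buffer only when the primitive prefers a different layout. A hedged sketch of that idiom as a standalone helper (oneDNN 1.x API assumed; prepare_input is a hypothetical name, not part of this codebase):

    #include <dnnl.hpp>

    // Return memory in the layout the primitive wants, copying from the
    // user's layout only when the two descriptors actually differ.
    dnnl::memory prepare_input(const dnnl::memory::desc& prim_md, dnnl::memory& user_mem,
                               const dnnl::engine& engine, dnnl::stream& stream) {
        if (prim_md == user_mem.get_desc())
            return user_mem;                    // layouts already match, zero-copy
        dnnl::memory prim_mem(prim_md, engine); // scratch buffer in the primitive's layout
        dnnl::reorder(user_mem, prim_mem).execute(stream, user_mem, prim_mem);
        return prim_mem;
    }

For outputs the same check runs in the opposite direction after the primitive executes, as the dst handling above shows; the commit moves all of this into mkldnnUtils::poolingMKLDNN.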

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);
@@ -141,12 +83,10 @@ PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(
            0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto gradO = INPUT_VARIABLE(
            1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(
            0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

    auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

    int kH = INT_ARG(0); // filter(kernel) height
    int kW = INT_ARG(1); // filter(kernel) width
@@ -156,92 +96,26 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
    int pW = INT_ARG(5); // paddings width
    int dH = INT_ARG(6); // dilations height
    int dW = INT_ARG(7); // dilations width
    int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
    int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
    int extraParam0 = INT_ARG(9);
    int isNCHW =
            block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC

    REQUIRE_TRUE(input->rankOf() == 4, 0,
                 "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dH != 0 && dW != 0, 0,
                 "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
                                               indIiH, indWiC, indWoC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    std::string expectedGradOShape = ShapeUtils::shapeAsString(
            ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
    std::string expectedGradIShape = ShapeUtils::shapeAsString(
            ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
                 "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
                 expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
                 "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
                 expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

    if (!isNCHW) {
        input = new NDArray(input->permute(
                {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        gradI = new NDArray(gradI->permute(
                {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        gradO = new NDArray(gradO->permute(
                {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    if (isSameMode) // SAME
    if(paddingMode) // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    auto poolingMode = PoolingType::AVG_POOL;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;
    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
                                           true,
                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
                                           &user_diff_src_md, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
                                           input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
                                           pool_dst_md, pool_strides, pool_kernel, pool_padding,
                                           pool_padding_r);
    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
                                             pool_kernel, pool_padding, pool_padding_r);
    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
    auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
    auto poolB_src_memory = userB_src_memory;
    dnnl::stream stream(engine);
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
    }
    auto poolB_dst_memory = userB_dst_memory;
    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
    }
    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
    }
    stream.wait();

    if (!isNCHW) {
        delete input;
        delete gradI;
        delete gradO;
    }
    auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;

    mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode);

    return Status::OK();
}
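Note: both the forward and backward ops map extraParam0 to oneDNN's two averaging algorithms. The only difference between them is the divisor used for windows that overlap the padded border; a minimal numeric sketch (plain C++, values chosen purely for illustration):

    #include <cstdio>

    int main() {
        // a 2x2 window hanging off the edge: two real values (3, 3), two padded zeros
        float vals[4] = {3.f, 3.f, 0.f, 0.f};
        int inBounds = 2;

        float sum = vals[0] + vals[1] + vals[2] + vals[3];
        printf("exclude_padding: %.2f\n", sum / inBounds); // 3.00, divide by valid count
        printf("include_padding: %.2f\n", sum / 4);        // 1.50, divide by window size
    }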
@@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/PlatformHelper.h>
@@ -29,113 +30,110 @@

using namespace dnnl;

namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(
            0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto output = OUTPUT_VARIABLE(
            0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
namespace nd4j {
namespace ops {
namespace platforms {

    int kD = INT_ARG(0); // filter(kernel) depth
    int kH = INT_ARG(1); // filter(kernel) height
    int kW = INT_ARG(2); // filter(kernel) width
    int sD = INT_ARG(3); // strides depth
    int sH = INT_ARG(4); // strides height
    int sW = INT_ARG(5); // strides width
    int pD = INT_ARG(6); // paddings depth
    int pH = INT_ARG(7); // paddings height
    int pW = INT_ARG(8); // paddings width
    int dD = INT_ARG(9); // dilations depth
    int dH = INT_ARG(10); // dilations height
    int dW = INT_ARG(11); // dilations width
    int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
    int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) {

    REQUIRE_TRUE(input->rankOf() == 5, 0,
                 "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
                 input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
                 "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
    auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
                                               indIOioC, indIOioD, indWiC, indWoC, indWkD);
    int kD = INT_ARG(0); // filter(kernel) depth
    int kH = INT_ARG(1); // filter(kernel) height
    int kW = INT_ARG(2); // filter(kernel) width
    int sD = INT_ARG(3); // strides depth
    int sH = INT_ARG(4); // strides height
    int sW = INT_ARG(5); // strides width
    int pD = INT_ARG(6); // paddings depth
    int pH = INT_ARG(7); // paddings height
    int pW = INT_ARG(8); // paddings width
    int dD = INT_ARG(9); // dilations depth
    int dH = INT_ARG(10); // dilations height
    int dW = INT_ARG(11); // dilations width
    int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
    int extraParam0 = INT_ARG(13);
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC

    std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
            {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
    REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
                 "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
                 expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
    // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
    // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
    REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW MKLDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    if (!isNCDHW) {
        input = new NDArray(
                input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
        output = new NDArray(
                output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
    }
    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    if (isSameMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
    if(paddingMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;

    mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode);

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}

//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

    const int kD = INT_ARG(0); // filter(kernel) depth
    const int kH = INT_ARG(1); // filter(kernel) height
    const int kW = INT_ARG(2); // filter(kernel) width
    const int sD = INT_ARG(3); // strides depth
    const int sH = INT_ARG(4); // strides height
    const int sW = INT_ARG(5); // strides width
    int pD = INT_ARG(6); // paddings depth
    int pH = INT_ARG(7); // paddings height
    int pW = INT_ARG(8); // paddings width
    const int dD = INT_ARG(9); // dilations depth
    const int dH = INT_ARG(10); // dilations height
    const int dW = INT_ARG(11); // dilations width
    const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
    const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging
    const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC

    REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

    if(paddingMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;

    mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode);

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}

    auto poolingMode = PoolingType::AVG_POOL;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;
    mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
                                           extraParam0, true,
                                           bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
                                           algorithm,
                                           &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
                                           &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
                                           pool_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    dnnl::stream stream(engine);
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
    auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
    auto pool_src_memory = user_src_memory;
    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
    }
    auto pool_dst_memory = user_dst_memory;
    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
    }
    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
                                                     {DNNL_ARG_DST, pool_dst_memory}});
    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
    }
    stream.wait();

    if (!isNCDHW) {
        delete input;
        delete output;
    }

    return Status::OK();
}

PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}
}
}
@@ -1,154 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>

#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>

using namespace dnnl;

namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(
            0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto gradO = INPUT_VARIABLE(
            1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(
            0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

    const int kD = INT_ARG(0); // filter(kernel) depth
    const int kH = INT_ARG(1); // filter(kernel) height
    const int kW = INT_ARG(2); // filter(kernel) width
    const int sD = INT_ARG(3); // strides depth
    const int sH = INT_ARG(4); // strides height
    const int sW = INT_ARG(5); // strides width
    int pD = INT_ARG(6); // paddings depth
    int pH = INT_ARG(7); // paddings height
    int pW = INT_ARG(8); // paddings width
    const int dD = INT_ARG(9); // dilations depth
    const int dH = INT_ARG(10); // dilations height
    const int dW = INT_ARG(11); // dilations width
    const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
    int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW

    REQUIRE_TRUE(input->rankOf() == 5, 0,
                 "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
                 "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
                                               indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
            {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
            {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
                 "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
                 expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
                 "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
                 expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if (!isNCDHW) {
        input = new NDArray(input->permute(
                {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
        gradI = new NDArray(gradI->permute(
                {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
        gradO = new NDArray(gradO->permute(
                {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
    }

    if (isSameMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    auto poolingMode = PoolingType::AVG_POOL;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;
    mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
                                           extraParam0, true,
                                           bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
                                           algorithm,
                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
                                           &user_diff_src_md, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    if (input->buffer() == nullptr) {
        pool_src_md = pool_diff_src_md;
        user_src_md = user_diff_src_md;
    }
    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    dnnl::stream stream(engine);
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
                                             pool_kernel, pool_padding, pool_padding_r);
    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
    auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
    auto poolB_src_memory = userB_src_memory;
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
    }
    auto poolB_dst_memory = userB_dst_memory;
    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
    }
    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
    }
    stream.wait();

    if (!isNCDHW) {
        delete input;
        delete gradI;
        delete gradO;
    }

    return Status::OK();
}

PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}
@@ -37,12 +37,12 @@ namespace platforms {

//////////////////////////////////////////////////////////////////////////
static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) {
static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, NDArray* z,
                            const float epsilon, const bool isNCHW) {

    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
    // also it gives wrong results for formats nhwc and ndhwc
    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x

    // x -> 2D:nc, 4D:nchw, 5D:ncdhw
    // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    // mean -> 1D [c]
    // variance -> 1D [c]
    // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
@@ -50,8 +50,6 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray

    const int xRank = x->rankOf();

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // input type
    dnnl::memory::data_type type = dnnl::memory::data_type::f32;
@@ -63,17 +61,28 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    dnnl::memory::dims dims;
    dnnl::memory::format_tag format;

    const int indHW = isNCHW ? 2 : 1;
    const int bS = x->sizeAt(0);
    const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1);

    int iD, iH, iW;

    if(xRank == 2) {
        dims = {x->sizeAt(0), x->sizeAt(1)};
        dims = {bS, iC};
        format = dnnl::memory::format_tag::nc;
    }
    else if(xRank == 4) {
        dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
        format = dnnl::memory::format_tag::nchw;
        iH = x->sizeAt(indHW);
        iW = x->sizeAt(indHW + 1);
        dims = {bS, iC, iH, iW};
        format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    }
    else { // xRank = 5
        dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
        format = dnnl::memory::format_tag::ncdhw;
        iD = x->sizeAt(indHW);
        iH = x->sizeAt(indHW + 1);
        iW = x->sizeAt(indHW + 2);
        dims = {bS, iC, iD, iH, iW};
        format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    }
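Note: the same rank/layout dispatch appears again in the backward helper below. As a hedged sketch, it could be factored into a small helper like this (plainFormat is a hypothetical name, oneDNN 1.x format tags):

    #include <dnnl.hpp>

    // map tensor rank + data layout to the matching plain oneDNN format tag
    dnnl::memory::format_tag plainFormat(int rank, bool isNCHW) {
        using tag = dnnl::memory::format_tag;
        if (rank == 2) return tag::nc;
        if (rank == 4) return isNCHW ? tag::nchw : tag::nhwc;
        return isNCHW ? tag::ncdhw : tag::ndhwc; // rank == 5
    }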

    // memory descriptors for arrays
@@ -81,29 +90,34 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    // x
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
    x_user_md.data.format_kind = dnnl_blocked; // overrides format
    x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
    x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
    if(xRank > 2) {
        x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
        x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
    if(x->ews() != 1 || x->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked; // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
        if(xRank > 2) {
            x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
            x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
        }
        if(xRank > 4)
            x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
    }
    if(xRank > 4)
        x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];

    // z, output
    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
    z_user_md.data.format_kind = dnnl_blocked; // overrides format
    z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0];
    z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1];
    if(xRank > 2) {
        z_user_md.data.format_desc.blocking.strides[2] = z->stridesOf()[2];
        z_user_md.data.format_desc.blocking.strides[3] = z->stridesOf()[3];
    if(z->ews() != 1 || z->ordering() != 'c') {
        z_user_md.data.format_kind = dnnl_blocked; // overrides format
        z_user_md.data.format_desc.blocking.strides[0] = z->strideAt(0);
        z_user_md.data.format_desc.blocking.strides[1] = z->strideAt(1);
        if(xRank > 2) {
            z_user_md.data.format_desc.blocking.strides[2] = z->strideAt(2);
            z_user_md.data.format_desc.blocking.strides[3] = z->strideAt(3);
        }
        if(xRank > 4)
            z_user_md.data.format_desc.blocking.strides[4] = z->strideAt(4);
    }
    if(xRank > 4)
        z_user_md.data.format_desc.blocking.strides[4] = z->stridesOf()[4];

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // batchnorm forward description
    dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
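Note: the ews()/ordering() guards introduced above are the core of this change: a plain c-order contiguous NDArray is already fully described by the format tag, so the dnnl_blocked override (explicit per-axis strides) is needed only for views and f-order arrays. As a hedged aside, oneDNN 1.x also accepts the strides directly in the descriptor constructor, which would express the same thing without touching desc internals:

    #include <dnnl.hpp>

    // describe an arbitrarily strided f32 tensor via explicit strides
    dnnl::memory::desc stridedDesc(const dnnl::memory::dims& dims,
                                   const dnnl::memory::dims& strides) {
        return dnnl::memory::desc(dims, dnnl::memory::data_type::f32, strides);
    }
    // e.g. an NHWC-laid-out buffer described in NCHW axis order:
    // dims {N, C, H, W}, strides {H*W*C, 1, W*C, C}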
@@ -162,12 +176,11 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray

//////////////////////////////////////////////////////////////////////////
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
                                    const float epsilon, NDArray* dLdI, NDArray* dLdW) {
                                    NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) {

    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
    // also it gives wrong results for formats nhwc and ndhwc
    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x

    // x -> 2D:nc, 4D:nchw, 5D:ncdhw
    // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    // mean -> 1D [c]
    // variance -> 1D [c]
    // dLdO - same shape as x
@@ -177,8 +190,6 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const

    const int xRank = x->rankOf();

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // input type
    dnnl::memory::data_type type = dnnl::memory::data_type::f32;
@@ -190,17 +201,28 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    dnnl::memory::dims dims;
    dnnl::memory::format_tag format;

    const int indHW = isNCHW ? 2 : 1;
    const int bS = x->sizeAt(0);
    const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1);

    int iD, iH, iW;

    if(xRank == 2) {
        dims = {x->sizeAt(0), x->sizeAt(1)};
        dims = {bS, iC};
        format = dnnl::memory::format_tag::nc;
    }
    else if(xRank == 4) {
        dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
        format = dnnl::memory::format_tag::nchw;
        iH = x->sizeAt(indHW);
        iW = x->sizeAt(indHW + 1);
        dims = {bS, iC, iH, iW};
        format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    }
    else { // xRank = 5
        dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
        format = dnnl::memory::format_tag::ncdhw;
        iD = x->sizeAt(indHW);
        iH = x->sizeAt(indHW + 1);
        iW = x->sizeAt(indHW + 2);
        dims = {bS, iC, iD, iH, iW};
        format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    }

    // memory descriptors for arrays
@@ -208,41 +230,49 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    // x
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
    x_user_md.data.format_kind = dnnl_blocked; // overrides format
    x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
    x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
    if(xRank > 2) {
        x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
        x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
    if(x->ews() != 1 || x->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked; // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
        if(xRank > 2) {
            x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
            x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
        }
        if(xRank > 4)
            x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
    }
    if(xRank > 4)
        x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];

    // dLdO
    dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
    dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format
    dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0];
    dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1];
    if(xRank > 2) {
        dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->stridesOf()[2];
        dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->stridesOf()[3];
    if(dLdO->ews() != 1 || dLdO->ordering() != 'c') {
        dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format
        dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->strideAt(0);
        dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->strideAt(1);
        if(xRank > 2) {
            dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->strideAt(2);
            dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->strideAt(3);
        }
        if(xRank > 4)
            dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->strideAt(4);
    }
    if(xRank > 4)
        dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4];

    // dLdI
    dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
    dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format
    dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0];
    dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1];
    if(xRank > 2) {
        dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->stridesOf()[2];
        dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->stridesOf()[3];
    if(dLdI->ews() != 1 || dLdI->ordering() != 'c') {
        dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format
        dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->strideAt(0);
        dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->strideAt(1);
        if(xRank > 2) {
            dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->strideAt(2);
            dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->strideAt(3);
        }
        if(xRank > 4)
            dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->strideAt(4);
    }
    if(xRank > 4)
        dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4];

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // batchnorm forward description
    dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
@@ -331,7 +361,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    // dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m))
    // dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) )

    std::vector<int> axes = {1};
    std::vector<int> axes = isNCHW ? std::vector<int>{1} : std::vector<int>{xRank - 1};
    const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes);

    // inversed batch size 1 / N
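Note: the two comment lines above compress the whole derivation. In standard notation (the textbook batchnorm backprop identity, stated here for context rather than taken verbatim from the patch), with g_i = dL/dy_i, N the number of elements per channel, and x-hat the normalized input:

    \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad
    \frac{\partial L}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}
        \left( g_i - \frac{1}{N}\sum_{j=1}^{N} g_j
                   - \frac{\hat{x}_i}{N}\sum_{j=1}^{N} g_j\,\hat{x}_j \right)

The axes change above makes the per-channel sums run over the right dimension: with NHWC/NDHWC the channel axis is the last one, hence axes = {xRank - 1} instead of the hard-coded {1}.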
@@ -377,7 +407,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const

PLATFORM_IMPL(batchnorm, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
    auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    auto mean = INPUT_VARIABLE(1); // [c]
    auto variance = INPUT_VARIABLE(2); // [c]
    NDArray* gamma = nullptr; // [c]
@@ -436,27 +466,19 @@ PLATFORM_IMPL(batchnorm, ENGINE_CPU) {
        (*weights)({1,2, 0,0}).assign(0);
    }

    if(axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc
        std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
        input = new NDArray(input->permute(permut));
        output = new NDArray(output->permute(permut));
    }
    const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);

    batchnormMKLDNN(input, mean, variance, weights, epsilon, output);
    batchnormMKLDNN(input, mean, variance, weights, output, epsilon, isNCHW);

    delete weights;

    if(axes[0] == inRank - 1 && inRank > 2) {
        delete input;
        delete output;
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw

    auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    auto mean = INPUT_VARIABLE(1); // [c]
    auto variance = INPUT_VARIABLE(2); // [c]
    NDArray* gamma = nullptr; // [c]
@@ -630,7 +652,7 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {

    NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
    NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    NDArray* mean = INPUT_VARIABLE(1); // [c]
    NDArray* variance = INPUT_VARIABLE(2); // [c]
    NDArray* gamma = nullptr; // [c]
@@ -698,15 +720,9 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
        (*weights)({1,2, 0,0}).assign(0);
    }

    const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);

    if(axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc
        std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
        input = new NDArray(input->permute(permut));
        dLdO = new NDArray(dLdO->permute(permut));
        dLdI = new NDArray(dLdI->permute(permut));
    }

    batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW);
    batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, dLdI, dLdW, epsilon, isNCHW);

    *dLdM = 0;
    *dLdV = 0;
@@ -721,17 +737,12 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
        delete dLdW;
    }

    if(axes[0] == inRank - 1 && inRank > 2) {
        delete input;
        delete dLdO;
        delete dLdI;
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm_bp, ENGINE_CPU) {

    NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
    NDArray* mean = INPUT_VARIABLE(1); // [c]
    NDArray* variance = INPUT_VARIABLE(2); // [c]
@@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/PlatformHelper.h>
@@ -33,6 +34,298 @@ namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////
static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
                         const NDArray *bias, NDArray *output,
                         const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
                         const int paddingMode, const int isNCHW) {

    // weights [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW]

    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d

    dnnl::memory::dims strides = { sH, sW };
    dnnl::memory::dims padding = { pH, pW };
    dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    dnnl::memory::dims dilation = { dH-1, dW-1};

    auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;

    dnnl::memory::dims xDims = {bS, iC, iH, iW};
    dnnl::memory::dims wDims = {oC, iC, kH, kW};
    dnnl::memory::dims zDims = {bS, oC, oH, oW};

    auto type = dnnl::memory::data_type::f32;

    // memory descriptors for arrays

    // input
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    if(input->ews() != 1 || input->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked; // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
    }

    // weights
    dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
    w_user_md.data.format_kind = dnnl_blocked; // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);

    // bias
    dnnl::memory::desc b_mkl_md;
    if(bias != nullptr)
        b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);

    // output
    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
    if(output->ews() != 1 || output->ordering() != 'c') {
        z_user_md.data.format_kind = dnnl_blocked; // overrides format
        z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
    }

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // operation primitive description
    dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
    dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine);

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> args;

    dnnl::stream stream(engine);

    // provide memory buffers and check whether reorder is required

    // input
    auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
    const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
    auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
    if (xReorder)
        dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
    args[DNNL_ARG_SRC] = x_mkl_mem;

    // weights
    auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
    const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
    auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
    if (wReorder)
        dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
    args[DNNL_ARG_WEIGHTS] = w_mkl_mem;

    // bias
    if(bias != nullptr) {
        auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
        args[DNNL_ARG_BIAS] = b_mkl_mem;
    }

    // output
    auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
    args[DNNL_ARG_DST] = z_mkl_mem;

    // run calculations
    dnnl::convolution_forward(op_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);

    stream.wait();
    // shape::printArray(z_mkl_mem.map_data<float>(),8);
}
|
||||
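
Two oneDNN conventions are worth calling out in the helper above: dilation is zero-based (hence the dH-1, dW-1 passed as dilation), and the right/bottom padding follows from the output-size identity o = (i + pL + pR - ((k - 1)*d + 1)) / s + 1 solved for pR, which reduces to the (oH - 1)*sH - iH + kH - pH form used above when d == 1. A minimal standalone sketch (illustrative names, not part of this PR):

#include <cstdio>

static int dnnlDilation(int d)           { return d - 1; }           // framework d -> oneDNN
static int effectiveKernel(int k, int d) { return (k - 1) * d + 1; } // dilated kernel extent

static int rightPadding(int o, int s, int i, int k, int p, int d) {
    // padding_r that makes the output-size identity close
    return (o - 1) * s - i + effectiveKernel(k, d) - p;
}

int main() {
    // example: i = 7, k = 3, s = 2, pL = 1, d = 1 -> o = 4
    const int i = 7, k = 3, s = 2, p = 1, d = 1;
    const int o = (i + 2 * p - effectiveKernel(k, d)) / s + 1;
    std::printf("o = %d, dnnl dilation = %d, padding_r = %d\n",
                o, dnnlDilation(d), rightPadding(o, s, i, k, p, d));
    return 0;
}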

//////////////////////////////////////////////////////////////////////
static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
NDArray *gradI, NDArray *gradW, NDArray *gradB,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int paddingMode, const int isNCHW) {

// weights/gradW [kH, kW, iC, oC], we'll perform permutation since mkl supports only [oC, iC, kH, kW]

int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d

dnnl::memory::dims strides = { sH, sW };
dnnl::memory::dims padding = { pH, pW };
dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
dnnl::memory::dims dilation = { dH-1, dW-1 };

auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;

dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kH, kW};
dnnl::memory::dims zDims = {bS, oC, oH, oW};

auto type = dnnl::memory::data_type::f32;

// memory descriptors for arrays

// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
}

// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);

// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
}

// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
}

// gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(2);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);

// gradB
dnnl::memory::desc gradB_mkl_md;
if(gradB != nullptr)
gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);

auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

// forward primitive description
dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);

// backward data primitive description
dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);

// backward weights primitive description
dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);

// arguments (memory buffers) necessary for calculations
std::unordered_map<int, dnnl::memory> args;

dnnl::stream stream(engine);

// provide memory buffers and check whether reorder is required

// input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;

// weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;

// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorderW)
dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
if (gradOReorderD)
dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;

// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;

// gradW
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;

// gradB
if(gradB != nullptr) {
auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
}

// run backward data calculations
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);

if(gradOReorderW || gradOReorderD)
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;

// run backward weights calculations
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);

// reorder gradI and gradW if necessary
if (gradIReorder)
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);

stream.wait();
}
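
Note how the same five-line pattern repeats for every tensor above: compare the user descriptor against the one the primitive selected, and reorder only on mismatch. A hedged refactoring sketch of that step (hypothetical helper, not part of this PR; dnnl::memory::desc supports operator==, so a plain comparison detects the layout mismatch):

#include <dnnl.hpp>

static dnnl::memory provideMemory(const dnnl::memory::desc& primMd,
                                  dnnl::memory& userMem,
                                  const dnnl::engine& engine,
                                  dnnl::stream& stream) {
    if (primMd == userMem.get_desc())
        return userMem;                   // layouts match -> reuse the user buffer
    dnnl::memory primMem(primMd, engine); // scratch buffer in the primitive's layout
    dnnl::reorder(userMem, primMem).execute(stream, userMem, primMem);
    return primMem;
}

gradO is the one argument that may need two such scratch copies, because the backward-data and backward-weights primitives each report their own diff_dst_desc(); that is why the code keeps gradO_mkl_memD and gradO_mkl_memW and swaps DNNL_ARG_DIFF_DST between the two execute calls.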

/*
//////////////////////////////////////////////////////////////////////
static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,

@@ -46,37 +339,37 @@ static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, cons
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

dnnl_memory_desc_t empty;
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
dnnl::memory::desc x_mkl_md(empty), w_mkl_md(empty), b_mkl_md(empty), z_mkl_md(empty);
dnnl::memory::desc x_user_md(empty), w_user_md(empty), b_user_md(empty), z_user_md(empty);

dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
dnnl::memory::dims strides, padding, padding_r, dilation;

mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
bias, output,
&conv_src_md, nullptr, &conv_weights_md, nullptr,
&conv_bias_md, &conv_dst_md,
&user_src_md, nullptr, &user_weights_md, nullptr,
&user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
&x_mkl_md, nullptr, &w_mkl_md, nullptr,
&b_mkl_md, &z_mkl_md,
&x_user_md, nullptr, &w_user_md, nullptr,
&b_user_md, &z_user_md,
strides, padding, padding_r, dilation);

auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r)
algorithm::convolution_auto, x_mkl_md,
w_mkl_md, b_mkl_md,
z_mkl_md, strides, dilation, padding,
padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
algorithm::convolution_auto, x_mkl_md,
w_mkl_md,
z_mkl_md, strides, dilation, padding,
padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = dnnl::memory(user_weights_md, engine,
auto user_src_memory = dnnl::memory(x_user_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = dnnl::memory(w_user_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto user_dst_memory = dnnl::memory(z_user_md, engine, output->buffer());
auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);

@@ -239,13 +532,16 @@ static void conv2dBpMKLDNN(nd4j::graph::Context &block,
}
}

*/

//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width

@@ -254,16 +550,28 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC

int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width

conv2dMKLDNN(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);

return Status::OK();
}
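
For context, the shape checks above lean on the standard SAME/VALID output-size arithmetic. A small sketch of the semantics we assume behind calcOutSizePool2D / calcPadding2D (helper names here are illustrative, not the library's):

#include <cstdio>

static int outSizeValid(int i, int k, int s, int d) {
    return (i - ((k - 1) * d + 1)) / s + 1;   // floor division
}
static int outSizeSame(int i, int s) {
    return (i + s - 1) / s;                   // ceil(i / s)
}

int main() {
    std::printf("VALID: i=7 k=3 s=2 d=1 -> o=%d\n", outSizeValid(7, 3, 2, 1)); // o=3
    std::printf("SAME : i=7 s=2        -> o=%d\n", outSizeSame(7, 2));         // o=4
    return 0;
}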

PLATFORM_CHECK(conv2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);

@@ -276,10 +584,10 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) {
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {

auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always

@@ -293,19 +601,33 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC

REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf());
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

conv2dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);

if(paddingMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);

return Status::OK();
}
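
The expectedGradOShape check above relies on ShapeUtils::composeShapeUsingDimsAndIdx. A sketch of what we assume that helper does, judging only from its call sites (the first half of the list holds dimension values, the second half the axis each value is scattered to):

#include <cstdio>
#include <vector>

static std::vector<long long> composeShape(const std::vector<long long>& dimsAndIdx) {
    const size_t rank = dimsAndIdx.size() / 2;
    std::vector<long long> shape(rank);
    for (size_t i = 0; i < rank; ++i)
        shape[dimsAndIdx[rank + i]] = dimsAndIdx[i];   // value i goes to axis idx[i]
    return shape;
}

int main() {
    // NHWC gradO check from above: {bS,oC,oH,oW} scattered to axes {0,3,1,2}
    for (long long v : composeShape({2, 8, 5, 5, /*idx:*/ 0, 3, 1, 2}))
        std::printf("%lld ", v);                        // prints: 2 5 5 8
    return 0;
}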

PLATFORM_CHECK(conv2d_bp, ENGINE_CPU) {

auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]

@@ -33,6 +33,314 @@ namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////
static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
const NDArray *bias, NDArray *output,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW,
const int paddingMode, const int isNCDHW) {

// weights [kD, kH, kW, iC, oC], we'll perform permutation since mkl supports only [oC, iC, kD, kH, kW]

int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

// const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d

dnnl::memory::dims strides = {sD, sH, sW};
dnnl::memory::dims padding = {pD, pH, pW};
// dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
dnnl::memory::dims dilation = {dD-1, dH-1, dW-1};

auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;

dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};

auto type = dnnl::memory::data_type::f32;

// memory descriptors for arrays

// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
}

// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);

// bias
dnnl::memory::desc b_mkl_md;
if(bias != nullptr)
b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);

// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
if(output->ews() != 1 || output->ordering() != 'c') {
z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
}

auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

// operation primitive description
dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine);

// arguments (memory buffers) necessary for calculations
std::unordered_map<int, dnnl::memory> args;

dnnl::stream stream(engine);

// provide memory buffers and check whether reorder is required

// input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;

// weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;

// bias
if(bias != nullptr) {
auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
args[DNNL_ARG_BIAS] = b_mkl_mem;
}

// output
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;

// run calculations
dnnl::convolution_forward(op_prim_desc).execute(stream, args);

// reorder output if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);

stream.wait();
}
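
One subtlety in the descriptors above: the ndhwc tag keeps the logical dims {bS, iC, iD, iH, iW} but implies channels-last strides over a dense row-major [bS, iD, iH, iW, iC] buffer. A small assumption-checking sketch (plain arithmetic, not library code):

#include <cstdio>

int main() {
    const long bS = 2, iC = 3, iD = 4, iH = 5, iW = 6;
    const long strides[5] = {
        iD * iH * iW * iC,  // bS: step over one full sample
        1,                  // iC: channels are innermost in ndhwc
        iH * iW * iC,       // iD
        iW * iC,            // iH
        iC                  // iW
    };
    for (long s : strides) std::printf("%ld ", s);  // prints: 360 1 90 18 3
    return 0;
}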

//////////////////////////////////////////////////////////////////////
static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
NDArray *gradI, NDArray *gradW, NDArray *gradB,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW,
const int paddingMode, const int isNCDHW) {

// weights/gradW [kD, kH, kW, iC, oC], we'll perform permutation since mkl supports only [oC, iC, kD, kH, kW]

int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

// const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d

dnnl::memory::dims strides = {sD, sH, sW};
dnnl::memory::dims padding = {pD, pH, pW};
// dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
dnnl::memory::dims dilation = {dD-1, dH-1, dW-1};

auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;

dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};

auto type = dnnl::memory::data_type::f32;

// memory descriptors for arrays

// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
}

// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);

// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
}

// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
}

// gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2);

// gradB
dnnl::memory::desc gradB_mkl_md;
if(gradB != nullptr)
gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);

auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

// forward primitive description
dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);

// backward data primitive description
dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);

// backward weights primitive description
dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);

// arguments (memory buffers) necessary for calculations
std::unordered_map<int, dnnl::memory> args;

dnnl::stream stream(engine);

// provide memory buffers and check whether reorder is required

// input
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;

// weights
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;

// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorderW)
dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
if (gradOReorderD)
dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;

// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;

// gradW
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;

// gradB
if(gradB != nullptr) {
auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
}

// run backward data calculations
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);

if(gradOReorderW || gradOReorderD)
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;

// run backward weights calculations
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);

// reorder gradI and gradW if necessary
if (gradIReorder)
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);

stream.wait();
}
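
The stride shuffling applied to w_user_md and gradW_user_md above is a zero-copy permutation: a dense [kD, kH, kW, iC, oC] buffer is viewed under the logical order [oC, iC, kD, kH, kW] without moving any data. An illustrative sketch of the arithmetic:

#include <cstdio>

int main() {
    const long kD = 2, kH = 3, kW = 3, iC = 4, oC = 8;
    // row-major strides of the physical [kD, kH, kW, iC, oC] layout
    const long sKD = kH * kW * iC * oC, sKH = kW * iC * oC,
               sKW = iC * oC, sIC = oC, sOC = 1;
    // reordered to the logical [oC, iC, kD, kH, kW] axis order,
    // exactly the strideAt(4), strideAt(3), strideAt(0..2) shuffle above
    const long view[5] = { sOC, sIC, sKD, sKH, sKW };
    for (long s : view) std::printf("%ld ", s);  // prints: 1 8 288 96 32
    return 0;
}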

/*
//////////////////////////////////////////////////////////////////////
static void conv3dMKLDNN(nd4j::graph::Context &block,
const NDArray *input, const NDArray *weights, const NDArray *bias,

@@ -225,6 +533,7 @@ static void conv3dBpMKLDNN(nd4j::graph::Context &block,
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
}
}
*/

//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {

@@ -256,15 +565,15 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

if (paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

conv3dMKLDNN(block, input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);

return Status::OK();
}

@@ -280,6 +589,7 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {

//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {

auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]

@@ -318,14 +628,14 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

conv3dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);

return Status::OK();
}
@ -34,17 +34,13 @@ namespace platforms {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
|
||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||
const int paddingMode) {
|
||||
const int paddingMode, const bool isNCHW) {
|
||||
|
||||
// input [bS, iC, iH, iW] nchw, mkl doesn't support format nhwc
|
||||
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
|
||||
// bias [oC], may be nullptr
|
||||
|
||||
// output [bS, oC, oH, oW] nchw, mkl doesn't support format nhwc
|
||||
// weights [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||
|
||||
dnnl::memory::dims strides = { sH, sW };
|
||||
dnnl::memory::dims padding = { pH, pW };
|
||||
|
@ -80,8 +76,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
else
|
||||
zType = dnnl::memory::data_type::s32;
|
||||
|
||||
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
|
||||
|
||||
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||
|
@ -93,20 +88,22 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
||||
}
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
|
||||
|
||||
// bias
|
||||
dnnl::memory::desc b_mkl_md;
|
||||
|
@ -116,11 +113,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
|
||||
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
|
||||
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
|
||||
z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3];
|
||||
if(output->ews() != 1 || output->ordering() != 'c') {
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
|
||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
|
||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
|
||||
}
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -179,21 +178,19 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||
const int paddingMode) {
|
||||
const int paddingMode, const bool isNCHW) {
|
||||
|
||||
// input and gradI [bS, iC, iH, iW], mkl doesn't support ndhwc format
|
||||
// weights and gradW [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
|
||||
// gradB [oC], may be nullptr
|
||||
// gradO [bS, oC, oH, oW]
|
||||
// weights and gradW [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||
|
||||
dnnl::memory::dims strides = { sH, sW };
|
||||
dnnl::memory::dims padding = { pH, pW };
|
||||
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||
dnnl::memory::dims dilation = { dH-1, dW-1 };
|
||||
|
||||
// input type
|
||||
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
// weights type
|
||||
|
@ -207,7 +204,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
|||
// gradB type
|
||||
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
|
||||
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
|
||||
|
||||
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||
|
@ -219,54 +216,59 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
||||
}
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
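Note the change of strategy for weights: instead of materializing weights->permute({2,3,0,1}) into a temporary array, the [kH, kW, oC, iC] buffer is left in place and described as a logical [oC, iC, kH, kW] view purely through strides; the reorder into w_mkl_md then performs the only physical copy. A small sketch of that index mapping (helper name is mine):

    #include <array>
    #include <cstdint>

    // Strides of the logical [oC, iC, kH, kW] view over a buffer whose native
    // layout is [kH, kW, oC, iC]; no data movement happens at this point.
    std::array<std::int64_t, 4> as_oihw(const std::array<std::int64_t, 4>& s) {
        return { s[2], s[3], s[0], s[1] };   // oC, iC, kH, kW
    }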
|
||||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
|
||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
||||
}
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
|
||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
||||
}
|
||||
|
||||
// gradW
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
|
||||
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
|
||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
|
||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
|
||||
|
||||
// gradB
|
||||
dnnl::memory::desc gradB_mkl_md;
|
||||
if(gradB != nullptr)
|
||||
gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);
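format_tag::x is oneDNN's plain layout for rank-1 tensors, so the [oC] bias gradient never needs a blocking-stride override. As a one-line sketch (helper name is mine):

    #include <cstdint>
    #include <dnnl.hpp>

    // Plain descriptor for a dense 1-D bias/bias-gradient buffer of length oC.
    dnnl::memory::desc bias_md(std::int64_t oC, dnnl::memory::data_type t) {
        return dnnl::memory::desc({oC}, t, dnnl::memory::format_tag::x);
    }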
|
||||
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// forward primitive description
|
||||
|
@ -306,11 +308,15 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
|||
|
||||
// gradO
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
if (gradOReorder)
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||
const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
if (gradOReorderW)
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
|
||||
if (gradOReorderD)
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
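gradO now gets two candidate layouts because the backward-weights and backward-data primitive descriptors are each free to choose a different optimal format for diff_dst; each copy is made only when its descriptor differs from the user memory, and the args map is repointed at gradO_mkl_memW right before the backward-weights pass (see the hunk below). The conditional-reorder idiom, as a sketch:

    #include <dnnl.hpp>

    // Copy user memory into a primitive's preferred layout only when the two
    // descriptors actually differ; otherwise reuse the user memory as-is.
    dnnl::memory to_layout(dnnl::memory& user_mem, const dnnl::memory::desc& want,
                           const dnnl::engine& eng, dnnl::stream& strm) {
        if (want == user_mem.get_desc())
            return user_mem;                               // no copy needed
        dnnl::memory out(want, eng);
        dnnl::reorder(user_mem, out).execute(strm, user_mem, out);
        return out;
    }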
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||
|
@ -333,6 +339,9 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
|||
// run backward data calculations
|
||||
dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
|
||||
if(gradOReorderW || gradOReorderD)
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
|
||||
|
||||
// run backward weights calculations
|
||||
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
|
@ -385,23 +394,7 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) {
|
|||
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
||||
}
|
||||
|
||||
// mkl supports only [oC, iC, kH, kW] format for weights
|
||||
weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
|
||||
// mkl supports only NCHW
|
||||
if(!isNCHW) {
|
||||
input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
output = new NDArray(output->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||
}
|
||||
|
||||
deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
|
||||
|
||||
delete weights;
|
||||
|
||||
if(!isNCHW) {
|
||||
delete input;
|
||||
delete output;
|
||||
}
|
||||
deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -477,27 +470,7 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {
|
|||
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
||||
}
|
||||
|
||||
// mkl supports only [oC, iC, kH, kW] for weights
|
||||
weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
gradW = new NDArray(gradW->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
|
||||
// mkl supports NCHW format only
|
||||
if(!isNCHW) {
|
||||
input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||
}
|
||||
|
||||
deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
|
||||
|
||||
delete weights;
|
||||
delete gradW;
|
||||
|
||||
if(!isNCHW) {
|
||||
delete input;
|
||||
delete gradI;
|
||||
delete gradO;
|
||||
}
|
||||
deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -33,7 +33,8 @@ namespace platforms {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
|
||||
const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
|
||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
|
||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||
const bool isNCHW) {
|
||||
|
||||
// gradI [bS, iH, iW, iC], mkl doesn't support nhwc format
|
||||
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
|
||||
|
@ -51,7 +52,7 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
|||
// gradI type
|
||||
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
|
||||
|
||||
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||
|
@ -67,29 +68,32 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
|||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
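Here the source order differs from deconv2d: TF-style weights come in as [kH, kW, iC, oC], so the strides of the [oC, iC, kH, kW] view are read in the order 3, 2, 0, 1 rather than 2, 3, 0, 1. As a one-liner sketch:

    #include <array>
    #include <cstdint>

    // [kH, kW, iC, oC] buffer viewed as [oC, iC, kH, kW], strides only.
    std::array<std::int64_t, 4> tf_as_oihw(const std::array<std::int64_t, 4>& s) {
        return { s[3], s[2], s[0], s[1] };
    }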
|
||||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
|
||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
||||
}
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
|
||||
|
||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
||||
}
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -166,9 +170,9 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
|
|||
|
||||
const int rank = gradO->rankOf();
|
||||
|
||||
REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
||||
REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
|
||||
REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());
|
||||
REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
||||
REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
|
||||
REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());
|
||||
|
||||
int indIOioC, indIiH, indWoC(3), indOoH;
|
||||
if(!isNCHW) {
|
||||
|
@ -193,29 +197,29 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
|
|||
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
||||
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||
|
||||
if(isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
// mkl supports only [oC, iC, kH, kW] for weights
|
||||
weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
|
||||
// // mkl supports only [oC, iC, kH, kW] for weights
|
||||
// weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
|
||||
|
||||
// mkl supports NCHW format only
|
||||
if(!isNCHW) {
|
||||
gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||
}
|
||||
// // mkl supports NCHW format only
|
||||
// if(!isNCHW) {
|
||||
// gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
// gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||
// }
|
||||
|
||||
deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW);
|
||||
deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW);
|
||||
|
||||
delete weights;
|
||||
// delete weights;
|
||||
|
||||
if(!isNCHW) {
|
||||
delete gradI;
|
||||
delete gradO;
|
||||
}
|
||||
// if(!isNCHW) {
|
||||
// delete gradI;
|
||||
// delete gradO;
|
||||
// }
|
||||
|
||||
// ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
|
||||
|
|
|
@ -34,17 +34,14 @@ namespace platforms {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
|
||||
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW,
|
||||
const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
|
||||
const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
|
||||
const bool isNCDHW) {
|
||||
|
||||
// input [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
// weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
|
||||
// bias [oC], may be nullptr
|
||||
|
||||
// output [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
|
||||
// weights [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||
|
||||
dnnl::memory::dims strides = { sD, sH, sW };
|
||||
dnnl::memory::dims padding = { pD, pH, pW };
|
||||
|
@ -80,8 +77,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
else
|
||||
zType = dnnl::memory::data_type::s32;
|
||||
|
||||
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw;
|
||||
dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
|
||||
|
||||
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
|
||||
|
@ -93,22 +89,24 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
|
||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
||||
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
|
||||
}
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
|
||||
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
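Same stride-view trick in 3-D: a [kD, kH, kW, oC, iC] buffer exposed as oneDNN's [oC, iC, kD, kH, kW] by reading the strides in the order 3, 4, 0, 1, 2. Sketch:

    #include <array>
    #include <cstdint>

    // [kD, kH, kW, oC, iC] buffer viewed as oidhw, strides only.
    std::array<std::int64_t, 5> as_oidhw(const std::array<std::int64_t, 5>& s) {
        return { s[3], s[4], s[0], s[1], s[2] };
    }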
|
||||
|
||||
// bias
|
||||
dnnl::memory::desc b_mkl_md;
|
||||
|
@ -118,12 +116,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
|
||||
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
|
||||
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
|
||||
z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3];
|
||||
z_user_md.data.format_desc.blocking.strides[4] = output->stridesOf()[4];
|
||||
if(output->ews() != 1 || output->ordering() != 'c') {
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
|
||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
|
||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
|
||||
z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
|
||||
}
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -184,16 +184,14 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
const int kD, const int kH, const int kW,
|
||||
const int sD, const int sH, const int sW,
|
||||
const int pD, const int pH, const int pW,
|
||||
const int dD, const int dH, const int dW) {
|
||||
const int dD, const int dH, const int dW,
|
||||
const bool isNCDHW) {
|
||||
|
||||
// input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format
|
||||
// weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
|
||||
// gradB [oC], may be nullptr
|
||||
// gradO [bS, oD, oH, oW, oC]
|
||||
// weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||
|
||||
dnnl::memory::dims strides = { sD, sH, sW };
|
||||
dnnl::memory::dims padding = { pD, pH, pW };
|
||||
|
@ -213,7 +211,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// gradB type
|
||||
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
|
||||
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw; // isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
|
||||
|
||||
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
|
||||
|
@ -225,52 +223,58 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
|
||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
||||
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
|
||||
}
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
|
||||
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
|
||||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
|
||||
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4];
|
||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
||||
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
|
||||
}
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
|
||||
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4];
|
||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
||||
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
|
||||
}
|
||||
|
||||
// gradW
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat);
|
||||
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
|
||||
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
|
||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
|
||||
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4];
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(4);
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
|
||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
|
||||
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2);
|
||||
|
||||
// gradB
|
||||
dnnl::memory::desc gradB_mkl_md;
|
||||
|
@ -317,11 +321,15 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
|
||||
// gradO
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
if (gradOReorder)
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||
const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
if (gradOReorderW)
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
|
||||
if (gradOReorderD)
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||
|
@ -344,6 +352,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// run backward data calculations
|
||||
dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
|
||||
if(gradOReorderW || gradOReorderD)
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
|
||||
|
||||
// run backward weights calculations
|
||||
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
|
@ -400,23 +411,7 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) {
|
|||
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
}
|
||||
|
||||
// mkl supports only [oC, iC, kD, kH, kW] format for weights
|
||||
weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
|
||||
// mkl supports only NCDHW
|
||||
if(!isNCDHW) {
|
||||
input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
output = new NDArray(output->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||
}
|
||||
|
||||
deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||
|
||||
delete weights;
|
||||
|
||||
if(!isNCDHW) {
|
||||
delete input;
|
||||
delete output;
|
||||
}
|
||||
deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -495,27 +490,7 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
|
|||
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
// mkl supports only [oC, iC, kD, kH, kW] for weights
|
||||
weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
gradW = new NDArray(gradW->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
|
||||
// mkl supports NCDHW format only
|
||||
if(!isNCDHW) {
|
||||
input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
gradO = new NDArray(gradO->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||
}
|
||||
|
||||
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||
|
||||
delete weights;
|
||||
delete gradW;
|
||||
|
||||
if(!isNCDHW) {
|
||||
delete input;
|
||||
delete gradI;
|
||||
delete gradO;
|
||||
}
|
||||
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -86,7 +86,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
else
|
||||
zType = dnnl::memory::data_type::s32;
|
||||
|
||||
dnnl::memory::format_tag xzFrmat = dnnl::memory::format_tag::nchw;
|
||||
dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;
|
||||
|
||||
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||
|
@ -98,11 +98,13 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format NHWC -> NCHW
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3);
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
|
||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); // do permutation NHWC -> NCHW
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
||||
}
|
||||
|
||||
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
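For depthwise convolution the weights go through oneDNN's grouped layout goihw with groups = iC, O = mC and I = 1, hence the comment's [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] mapping. A hedged sketch of the dims/strides this implies (wDims itself is defined earlier in the file and not shown in this hunk; names here are mine):

    #include <array>
    #include <cstdint>

    struct GroupedWeights {
        std::array<std::int64_t, 5> dims;     // {iC, mC, 1, kH, kW}
        std::array<std::int64_t, 5> strides;
    };

    // View a [kH, kW, iC, mC] buffer as goihw; the unit "I" axis has extent 1,
    // so its stride is never walked and can safely repeat a neighbour's value.
    GroupedWeights as_goihw(std::int64_t kH, std::int64_t kW, std::int64_t iC,
                            std::int64_t mC, const std::array<std::int64_t, 4>& s) {
        return { { iC, mC, 1, kH, kW }, { s[2], s[3], s[3], s[0], s[1] } };
    }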
|
||||
|
@ -122,11 +124,13 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat);
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 : 3);
|
||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
|
||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
|
||||
if(output->ews() != 1 || output->ordering() != 'c') {
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); // do permutation NHWC -> NCHW
|
||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
|
||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
|
||||
}
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -219,7 +223,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
// gradB type
|
||||
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
|
||||
|
||||
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;
|
||||
|
||||
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||
|
@ -230,12 +234,14 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
|
||||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3);
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
|
||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
|
||||
}
|
||||
|
||||
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
|
@ -249,21 +255,25 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 : 3);
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat);
|
||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
|
||||
}
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 : 3);
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat);
|
||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
|
||||
}
|
||||
|
||||
// gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||
|
@ -319,11 +329,15 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
|
||||
// gradO
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
if (gradOReorder)
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||
const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||
auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||
if (gradOReorderW)
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
|
||||
if (gradOReorderD)
|
||||
dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||
|
@ -346,6 +360,9 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
// run backward data calculations
|
||||
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
|
||||
if(gradOReorderW || gradOReorderD)
|
||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
|
||||
|
||||
// run backward weights calculations
|
||||
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
|
@ -401,6 +418,7 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto weights = INPUT_VARIABLE(1);
|
||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
|
||||
|
@ -473,6 +491,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
//
|
||||
// @author saudet
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <ops/declarable/PlatformHelper.h>
|
||||
|
@ -33,105 +34,38 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
|
||||
input->rankOf());
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
auto argI = *(block.getIArguments());
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const auto kH = INT_ARG(0);
|
||||
const auto kW = INT_ARG(1);
|
||||
const auto sH = INT_ARG(2);
|
||||
const auto sW = INT_ARG(3);
|
||||
int pH = INT_ARG(4);
|
||||
int pW = INT_ARG(5);
|
||||
const auto dH = INT_ARG(6);
|
||||
const auto dW = INT_ARG(7);
|
||||
const auto isSameMode = static_cast<bool>(INT_ARG(8));
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D MKLDNN OP: input array should have rank of 4, but got %i instead", input->rankOf());
|
||||
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
|
||||
dH, dW);
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
const int kH = INT_ARG(0);
|
||||
const int kW = INT_ARG(1);
|
||||
const int sH = INT_ARG(2);
|
||||
const int sW = INT_ARG(3);
|
||||
int pH = INT_ARG(4);
|
||||
int pW = INT_ARG(5);
|
||||
const int dH = INT_ARG(6);
|
||||
const int dW = INT_ARG(7);
|
||||
const int paddingMode = INT_ARG(8);
|
||||
// const int extraParam0 = INT_ARG(9);
|
||||
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW
|
||||
|
||||
int oH = 0;
|
||||
int oW = 0;
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
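The layout IArg is decoded once up front: INT_ARG(10) carries 0 for NCHW and 1 for NHWC, so the flag the helpers expect is its negation, defaulting to NCHW when the argument is absent. Equivalent standalone sketch:

    #include <vector>

    // 0 -> NCHW, 1 -> NHWC; a missing argument defaults to NCHW.
    bool resolve_isNCHW(const std::vector<int>& iArgs) {
        return iArgs.size() > 10 ? iArgs[10] == 0 : true;
    }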
|
||||
|
||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||
|
||||
if (!isNCHW) {
|
||||
input = new NDArray(
|
||||
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
output = new NDArray(
|
||||
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||
}
|
||||
|
||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||
|
||||
if (isSameMode)
|
||||
if (paddingMode)
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
const int bS = input->sizeAt(0);
|
||||
const int iC = input->sizeAt(1);
|
||||
const int oC = output->sizeAt(1);
|
||||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
int extraParam0 = 1;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
|
||||
algorithm,
|
||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||
&user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||
pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
dnnl::stream stream(engine);
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
delete output;
|
||||
}
|
||||
mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max);
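The entire hand-rolled primitive setup above collapses into one shared pooling helper; judging from this call site alone, the 2-D ops reuse a 3-D-shaped entry point by passing 0 for the depth-related kernel/stride/padding parameters. A hedged reconstruction of the declaration (the real one lives in mkldnnUtils and may differ):

    #include <dnnl.hpp>

    namespace nd4j { class NDArray; }

    namespace mkldnnUtils {
    // Assumed from the call site: shared by 2-D and 3-D pooling, with the
    // leading depth parameters (kD, sD, pD) passed as 0 by the 2-D ops.
    void poolingMKLDNN(const nd4j::NDArray* input, nd4j::NDArray* output,
                       const int kD, const int kH, const int kW,
                       const int sD, const int sH, const int sW,
                       const int pD, const int pH, const int pW,
                       const bool isNCHW, const dnnl::algorithm mode);
    }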
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -159,117 +93,24 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
|
|||
int pW = INT_ARG(5); // paddings width
|
||||
int dH = INT_ARG(6); // dilations height
|
||||
int dW = INT_ARG(7); // dilations width
|
||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||
int extraParam0 = INT_ARG(9);
|
||||
int isNCHW =
|
||||
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||
// int extraParam0 = INT_ARG(9);
|
||||
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
|
||||
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
+    REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

     int bS, iC, iH, iW, oC, oH, oW;                       // batch size, input channels, input height/width, output channels, output height/width
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
-                                               indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

-    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
-    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
-    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1});
+    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

-    if (!isNCHW) {
-        input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
-    }
-
-    if (isSameMode) // SAME
+    if (paddingMode) // SAME
         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

-    auto poolingMode = PoolingType::MAX_POOL;
-
-    dnnl_memory_desc_t empty;
-    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
-    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-    dnnl::algorithm algorithm;
-
-    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
-                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
-                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
-                                           &user_diff_src_md, &user_dst_md,
-                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    // input is sometimes null, so we can't rely on pool_src_md being valid
-    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
-                                           input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
-                                           pool_dst_md, pool_strides, pool_kernel, pool_padding,
-                                           pool_padding_r);
-
-    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-    dnnl::stream stream(engine);
-    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-
-    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
-                                             pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
-    auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
-    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
-
-    auto poolB_src_memory = userB_src_memory;
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
-    }
-
-    auto poolB_dst_memory = userB_dst_memory;
-    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
-        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
-        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
-    }
-
-    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
-    auto pool_src_memory = user_src_memory;
-    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
-        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
-        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
-    }
-
-    auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
-    auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
-
-    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
-                                                     {DNNL_ARG_DST, pool_dst_memory},
-                                                     {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
-    // probably wrong, fix that
-    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
-                                                       {DNNL_ARG_WORKSPACE, pool_workspace_memory},
-                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
-
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
-    }
-
-    stream.wait();
-
-    if (!isNCHW) {
-        delete input;
-        delete gradI;
-        delete gradO;
-    }
+    mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max);

     return Status::OK();
 }
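A note on the replacement call above: the shared helper takes depth, height and width parameters for both 2d and 3d pooling, so the 2d op hands zeros to the depth slots. A minimal annotated sketch of the same call (comments added here for orientation; poolingBpMKLDNN branches on input->rankOf() == 4 and never reads the depth slots for a 2d array):

    mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI,
                                 0, kH, kW,  // kernel  (depth slot unused for rank-4 input)
                                 0, sH, sW,  // strides (depth slot unused for rank-4 input)
                                 0, pH, pW,  // padding (depth slot unused for rank-4 input)
                                 isNCHW, algorithm::pooling_max);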
@@ -16,6 +16,7 @@
 //
 // @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <ops/declarable/PlatformHelper.h>
@@ -34,10 +35,9 @@ namespace platforms {

 //////////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
-    auto input = INPUT_VARIABLE(
-            0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto output = OUTPUT_VARIABLE(
-            0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
+    auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+    auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)

     int kD = INT_ARG(0); // filter(kernel) depth
     int kH = INT_ARG(1); // filter(kernel) height
@@ -51,95 +51,24 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
     int dD = INT_ARG(9);  // dilations depth
     int dH = INT_ARG(10); // dilations height
     int dW = INT_ARG(11); // dilations width
-    int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
+    int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
     // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
     int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW

-    REQUIRE_TRUE(input->rankOf() == 5, 0,
-                 "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
-                 input->rankOf());
-    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
-                 "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
+    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

     int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
-                                               indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

-    std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
-            {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
-    REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
-                 "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
-                 expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
-    // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
-    // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
+    if(paddingMode) // SAME
+        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

-    if (!isNCDHW) {
-        input = new NDArray(input->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
-    }
-
-    if (isSameMode) // SAME
-        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
-                                        dW);
-
-    auto poolingMode = PoolingType::MAX_POOL;
-    auto extraParam0 = 1;
-
-    dnnl_memory_desc_t empty;
-    dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
-    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-    dnnl::algorithm algorithm;
-
-    mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
-                                           extraParam0, true,
-                                           bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
-                                           algorithm,
-                                           &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
-                                           &user_dst_md,
-                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
-                                           pool_dst_md, pool_strides, pool_kernel, pool_padding,
-                                           pool_padding_r);
-
-    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-    dnnl::stream stream(engine);
-    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
-    auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
-
-    auto pool_src_memory = user_src_memory;
-    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
-        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
-        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
-    }
-
-    auto pool_dst_memory = user_dst_memory;
-    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-        pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
-    }
-
-    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
-                                                     {DNNL_ARG_DST, pool_dst_memory}});
-
-    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-        reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
-    }
-
-    stream.wait();
-
-    if (!isNCDHW) {
-        delete input;
-        delete output;
-    }
+    mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max);

     return Status::OK();

 }
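The rewritten forward path above delegates all DNNL bookkeeping to poolingMKLDNN, which declares the destination with format_tag::any so oneDNN can pick its preferred layout, and falls back to a reorder only when that choice differs from the user's buffer. The pattern, condensed from the helper implementation further below (a fragment, not a complete function):

    dnnl::memory::desc z_mkl_md(zDims, type, dnnl::memory::format_tag::any); // let oneDNN choose
    dnnl::pooling_forward::primitive_desc op_prim_desc(op_desc, engine);
    auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();  // picked another layout?
    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
    // ... execute the primitive on z_mkl_mem, then copy back only if needed:
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);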
 //////////////////////////////////////////////////////////////////////////

@@ -152,6 +81,7 @@ PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) {

 //////////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
+
     auto input = INPUT_VARIABLE(0);  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
     auto gradO = INPUT_VARIABLE(1);  // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
     auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

@@ -162,127 +92,30 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
     const int sD = INT_ARG(3); // strides depth
     const int sH = INT_ARG(4); // strides height
     const int sW = INT_ARG(5); // strides width
     int pD = INT_ARG(6); // paddings depth
     int pH = INT_ARG(7); // paddings height
     int pW = INT_ARG(8); // paddings width
     const int dD = INT_ARG(9);  // dilations depth
     const int dH = INT_ARG(10); // dilations height
     const int dW = INT_ARG(11); // dilations width
-    const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
+    const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
     // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
     int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW

-    REQUIRE_TRUE(input->rankOf() == 5, 0,
-                 "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
-    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
-                 "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

     int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
-                                               indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

-    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
-            {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
-    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
-            {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
-    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
-                 "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
-                 expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
-                 "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
-                 expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
+    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

-    if (!isNCDHW) {
-        input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
-    }
-
-    if (isSameMode) // SAME
-        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
-                                        dW);
+    if(paddingMode) // SAME
+        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

-    auto poolingMode = PoolingType::MAX_POOL;
-    auto extraParam0 = 1;
-
-    dnnl_memory_desc_t empty;
-    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
-    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-    dnnl::algorithm algorithm;
-
-    mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
-                                           extraParam0, true,
-                                           bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
-                                           algorithm,
-                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
-                                           &user_diff_src_md, &user_dst_md,
-                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    // input is sometimes null, so we can't rely on pool_src_md being valid
-    if (input->buffer() == nullptr) {
-        pool_src_md = pool_diff_src_md;
-        user_src_md = user_diff_src_md;
-    }
-    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-    dnnl::stream stream(engine);
-    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-
-    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
-    auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
-    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
-
-    auto poolB_src_memory = userB_src_memory;
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
-    }
-
-    auto poolB_dst_memory = userB_dst_memory;
-    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
-        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
-        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
-    }
-
-    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
-
-    auto pool_src_memory = user_src_memory;
-    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
-        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
-        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
-    }
-
-    auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
-    auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
-
-    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
-                                                     {DNNL_ARG_DST, pool_dst_memory},
-                                                     {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
-    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
-                                                       {DNNL_ARG_WORKSPACE, pool_workspace_memory},
-                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
-
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
-    }
-
-    stream.wait();
-
-    if (!isNCDHW) {
-        delete input;
-        delete gradI;
-        delete gradO;
-    }
+    mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max);

     return Status::OK();
 }
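For reference, the rule the backward op implements: max pooling routes each output gradient back to the single input element that won the forward max. With R_q denoting the input window of output element q,

    \frac{\partial L}{\partial x_p} \;=\; \sum_{q \,:\, \arg\max_{r \in R_q} x_r \,=\, p} \frac{\partial L}{\partial y_q}

which is why the DNNL backward primitive needs the argmax workspace that only a forward pass produces (see poolingBpMKLDNN below).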
@@ -16,9 +16,11 @@
 //
 // @author saudet
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

+#include <dnnl_types.h>
 #include <ops/declarable/helpers/convolutions.h>
 #include "mkldnnUtils.h"

 using namespace dnnl;
@@ -26,6 +28,314 @@ using namespace dnnl;
 namespace nd4j {
 namespace mkldnnUtils {

+//////////////////////////////////////////////////////////////////////
+void poolingMKLDNN(const NDArray *input, NDArray *output,
+                   const int kD, const int kH, const int kW,
+                   const int sD, const int sH, const int sW,
+                   const int pD, const int pH, const int pW,
+                   const int isNCHW, const dnnl::algorithm mode) {
+
+    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input
+    const int rank = input->rankOf();
+
+    int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;
+    dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims;
+    dnnl::memory::format_tag xzFrmat;
+
+    const auto type = dnnl::memory::data_type::f32;
+
+    if(rank == 4) { // 2d
+
+        ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+        strides   = { sH, sW };
+        kernel    = { kH, kW };
+        padding   = { pH, pW };
+        padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        xDims     = {bS, iC, iH, iW};
+        zDims     = {bS, oC, oH, oW};
+
+        xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    }
+    else { // 3d
+
+        ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
+
+        strides   = { sD, sH, sW };
+        kernel    = { kD, kH, kW };
+        padding   = { pD, pH, pW };
+        padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        xDims     = {bS, iC, iD, iH, iW};
+        zDims     = {bS, oC, oD, oH, oW};
+
+        xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+    }
+
+    // memory descriptors for arrays
+
+    // input
+    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, xzFrmat);
+    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+    if(input->ews() != 1 || input->ordering() != 'c') {
+        x_user_md.data.format_kind = dnnl_blocked; // overrides format
+        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : -1);
+        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
+        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
+    }
+
+    // output
+    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+    dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+    if(output->ews() != 1 || output->ordering() != 'c') {
+        z_user_md.data.format_kind = dnnl_blocked; // overrides format
+        z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
+        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 : -1);
+        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
+        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(isNCHW ? 4 : 3);
+    }
+
+    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+
+    // operation primitive description
+    dnnl::pooling_forward::desc op_desc(dnnl::prop_kind::forward_inference, mode, x_mkl_md, z_mkl_md, strides, kernel, padding, padding_r);
+    dnnl::pooling_forward::primitive_desc op_prim_desc(op_desc, engine);
+
+    // arguments (memory buffers) necessary for calculations
+    std::unordered_map<int, dnnl::memory> args;
+
+    dnnl::stream stream(engine);
+
+    // provide memory buffers and check whether reorder is required
+
+    // input
+    auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+    const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
+    auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
+    if (xReorder)
+        dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+    args[DNNL_ARG_SRC] = x_mkl_mem;
+
+    // output
+    auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
+    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
+    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
+    args[DNNL_ARG_DST] = z_mkl_mem;
+
+    // run calculations
+    dnnl::pooling_forward(op_prim_desc).execute(stream, args);
+
+    // reorder outputs if necessary
+    if (zReorder)
+        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+
+    stream.wait();
+}
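A note on the padding_r expressions above: oneDNN takes left/top and right/bottom padding separately, and the helper derives the right padding from the already known output size, i.e. padR = (oH - 1) * sH - iH + kH - pH for the dilation-free pooling case. A standalone sanity check of that expression (hypothetical sizes, plain C++):

    #include <cstdio>

    int main() {
        const int iH = 28, kH = 5, sH = 2, pH = 1;
        const int oH = (iH + 2 * pH - kH) / sH + 1;    // output rows for these sizes
        const int padR = (oH - 1) * sH - iH + kH - pH; // same expression as in poolingMKLDNN
        // the last window starts at (oH-1)*sH - pH, so it ends exactly at iH + padR
        std::printf("oH=%d padR=%d lastWindowEnd=%d iH+padR=%d\n",
                    oH, padR, (oH - 1) * sH - pH + kH, iH + padR);
        return 0; // prints oH=13 padR=0 lastWindowEnd=28 iH+padR=28
    }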
+//////////////////////////////////////////////////////////////////////
+void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
+                     const int kD, const int kH, const int kW,
+                     const int sD, const int sH, const int sW,
+                     const int pD, const int pH, const int pW,
+                     const int isNCHW, const dnnl::algorithm mode) {
+
+    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input
+
+    const int rank = input->rankOf();
+
+    int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;
+    dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims;
+    dnnl::memory::format_tag xzFrmat;
+
+    const auto type = dnnl::memory::data_type::f32;
+
+    if(rank == 4) { // 2d
+
+        ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+        strides   = { sH, sW };
+        kernel    = { kH, kW };
+        padding   = { pH, pW };
+        padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        xDims     = {bS, iC, iH, iW};
+        zDims     = {bS, oC, oH, oW};
+
+        xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    }
+    else { // 3d
+
+        ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
+
+        strides   = { sD, sH, sW };
+        kernel    = { kD, kH, kW };
+        padding   = { pD, pH, pW };
+        padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        xDims     = {bS, iC, iD, iH, iW};
+        zDims     = {bS, oC, oD, oH, oW};
+
+        xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+    }
+
+    // memory descriptors for arrays
+
+    // input
+    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, xzFrmat);
+    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+    if(input->ews() != 1 || input->ordering() != 'c') {
+        x_user_md.data.format_kind = dnnl_blocked; // overrides format
+        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : -1);
+        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
+        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
+    }
+
+    // gradO
+    dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+    if(gradO->ews() != 1 || gradO->ordering() != 'c') {
+        gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
+        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 : -1);
+        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
+        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(isNCHW ? 4 : 3);
+    }
+
+    // gradI
+    dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
+    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+    if(gradI->ews() != 1 || gradI->ordering() != 'c') {
+        gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
+        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 : -1);
+        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
+        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(isNCHW ? 4 : 3);
+    }
+
+    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+    dnnl::stream stream(engine);
+
+    // forward primitive description
+    dnnl::pooling_forward::desc op_ff_desc(dnnl::prop_kind::forward, mode, x_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r);
+    dnnl::pooling_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
+
+    // backward primitive description
+    dnnl::pooling_backward::desc op_bp_desc(mode, gradI_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r);
+    dnnl::pooling_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
+
+    // arguments (memory buffers) necessary for calculations
+    std::unordered_map<int, dnnl::memory> args;
+
+    // gradO
+    auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
+    const bool gradOReorder = op_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+    auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+    if (gradOReorder)
+        dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
+    args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
+
+    // gradI
+    auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
+    const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
+    auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
+    args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
+
+    if(mode == algorithm::pooling_max) {
+
+        // input
+        auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+        const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
+        auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
+        if (xReorder)
+            dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+        args[DNNL_ARG_SRC] = x_mkl_mem;
+
+        // z
+        auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);
+        args[DNNL_ARG_DST] = z_mkl_mem;
+
+        // auxiliary memory allocation
+        auto workspace = dnnl::memory(op_ff_prim_desc.workspace_desc(), engine);
+        args[DNNL_ARG_WORKSPACE] = workspace;
+
+        // run forward calculations
+        dnnl::pooling_forward(op_ff_prim_desc).execute(stream, args);
+    }
+
+    // run backward calculations
+    dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args);
+
+    // reorder gradI if necessary
+    if (gradIReorder)
+        dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
+
+    stream.wait();
+}
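Two details of poolingBpMKLDNN worth spelling out: the max-pooling backward primitive consumes the workspace (argmax indices) that only a forward run built with prop_kind::forward produces, so the helper re-executes the forward pass before the backward one; for average pooling the if(mode == algorithm::pooling_max) block is skipped and no forward run is needed. The dependency, condensed from the code above:

    auto workspace = dnnl::memory(op_ff_prim_desc.workspace_desc(), engine);
    args[DNNL_ARG_WORKSPACE] = workspace;
    dnnl::pooling_forward(op_ff_prim_desc).execute(stream, args);  // fills the argmax workspace
    dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args); // scatters gradO through it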
+//////////////////////////////////////////////////////////////////////////
+void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
+                            dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
+                            dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
+    const Nd4jLong* shape = src->getShapeInfo();
+    long rank = shape[0];
+    long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
+    long dim2 = axis >= 2 ? 1 : 2;
+    long dim3 = axis >= 3 ? 2 : 3;
+    dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
+
+    auto type = dnnl::memory::data_type::f32;
+    auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    auto supposed_to_be_any_format = format; // doesn't work with "any"
+
+    if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
+        *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+        *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+        user_src_md->data.format_kind = dnnl_blocked;
+        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
+        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
+        user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
+        user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
+    }
+
+    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
+        *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+        *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+        user_diff_src_md->data.format_kind = dnnl_blocked;
+        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
+        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
+        user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
+        user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
+    }
+
+    if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
+        *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+        *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+        user_dst_md->data.format_kind = dnnl_blocked;
+        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
+        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
+        user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
+        user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
+    }
+}
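The index juggling in getMKLDNNMemoryDescLrn folds an arbitrary-rank array into the 4d (N, C, spatial, spatial) descriptor DNNL expects, lifting the user-chosen axis into the channel slot. A worked example (hypothetical rank-4 NHWC input, axis = 3):

    // dim1 = 3, dim2 = (3 >= 2 ? 1 : 2) = 1, dim3 = (3 >= 3 ? 2 : 3) = 2
    // lrn_src_tz = { shape[1], shape[4], shape[2], shape[3] } = { N, C, H, W }
    // format resolves to nhwc (axis != 1), and the explicit blocking strides
    // keep the descriptor pointing at the original NHWC storage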
+//////////////////////////////////////////////////////////////////////////
+dnnl::engine& getEngine(void *ptr) {
+    auto eng = reinterpret_cast<dnnl::engine*>(ptr);
+    return *eng;
+}
+
+
+/*
 //////////////////////////////////////////////////////////////////////////
 void getMKLDNNMemoryDescPool2d(
         int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
@@ -307,104 +617,51 @@ void getMKLDNNMemoryDescConv3d(
 }
 };

-// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
-//                                   dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
-//                                   dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
-//     const Nd4jLong* shape = src->getShapeInfo();
-//     Nd4jLong rank = shape[0];
-//     Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
-//     Nd4jLong dim2 = axis >= 2 ? 1 : 2;
-//     Nd4jLong dim3 = axis >= 3 ? 2 : 3;
-//     dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
-
-//     auto type = dnnl::memory::data_type::f32;
-//     auto format = dnnl::memory::format_tag::nchw;
-//     auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
-
-//     if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
-//         *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
-//         *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
-//         user_src_md->data.format_kind = dnnl_blocked; // overrides format
-//         user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
-//         user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
-//         user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
-//         user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
-//     }
-
-//     if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
-//         *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
-//         *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
-//         user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
-//         user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
-//         user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
-//         user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
-//         user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
-//     }
-
-//     if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
-//         *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
-//         *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
-//         user_dst_md->data.format_kind = dnnl_blocked; // overrides format
-//         user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
-//         user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
-//         user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
-//         user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
-//     }
-// };
-
-//////////////////////////////////////////////////////////////////////////
-void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
-                            dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
-                            dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
-    const Nd4jLong* shape = src->getShapeInfo();
-    long rank = shape[0];
-    long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
-    long dim2 = axis >= 2 ? 1 : 2;
-    long dim3 = axis >= 3 ? 2 : 3;
-    dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
-
-    auto type = dnnl::memory::data_type::f32;
-    auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
-    auto supposed_to_be_any_format = format; // doesn't work with "any"
-
-    if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
-        *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
-        *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
-        user_src_md->data.format_kind = dnnl_blocked;
-        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
-        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
-        user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
-        user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
-    }
-
-    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
-        *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
-        *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
-        user_diff_src_md->data.format_kind = dnnl_blocked;
-        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
-        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
-        user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
-        user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
-    }
-
-    if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
-        *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
-        *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
-        user_dst_md->data.format_kind = dnnl_blocked;
-        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
-        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
-        user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
-        user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
-    }
-}
-
-//////////////////////////////////////////////////////////////////////////
-dnnl::engine& getEngine(void *ptr) {
-    auto eng = reinterpret_cast<dnnl::engine*>(ptr);
-    return *eng;
-}
+void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
+                                  dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
+                                  dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
+    const Nd4jLong* shape = src->getShapeInfo();
+    Nd4jLong rank = shape[0];
+    Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
+    Nd4jLong dim2 = axis >= 2 ? 1 : 2;
+    Nd4jLong dim3 = axis >= 3 ? 2 : 3;
+    dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
+
+    auto type = dnnl::memory::data_type::f32;
+    auto format = dnnl::memory::format_tag::nchw;
+    auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
+
+    if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
+        *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
+        *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
+        user_src_md->data.format_kind = dnnl_blocked; // overrides format
+        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
+        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
+        user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
+        user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
+    }
+
+    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
+        *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
+        *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
+        user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
+        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
+        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
+        user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
+        user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
+    }
+
+    if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
+        *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
+        *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
+        user_dst_md->data.format_kind = dnnl_blocked; // overrides format
+        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
+        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
+        user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
+        user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
+    }
+};
+*/

 }
 }

@@ -16,6 +16,7 @@
 //
 // @author saudet
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #ifndef DEV_TESTS_MKLDNNUTILS_H

@@ -81,17 +82,27 @@ namespace nd4j{
             DECLARE_PLATFORM(deconv3d_bp, ENGINE_CPU);
             DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU);
             DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU);
         }
     }

     namespace mkldnnUtils {

+        void poolingMKLDNN(const NDArray *input, NDArray *output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
+
+        void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
+
+        void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
+                                    dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
+                                    dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
+
+        dnnl::engine& getEngine(void *ptr);
+
         /**
         * Utility methods for MKLDNN
        */
-        void getMKLDNNMemoryDescConv2d(
+/*      void getMKLDNNMemoryDescConv2d(
             int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
             int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
             const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,

@@ -130,12 +141,7 @@ namespace nd4j{
-        void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
-                dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
-                dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
-
-        void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
-                dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
-                dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
-
-        dnnl::engine& getEngine(void *ptr);
+*/
     }
 }
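A minimal usage sketch for the two new entry points, as a caller such as a platform kernel might invoke them (sizes and layout here are assumed for illustration, not taken from the diff):

    using namespace dnnl;

    // 3d max pooling, NCDHW, 2x2x2 window, stride 2, no padding
    void runMaxPool3dWithGrad(nd4j::NDArray* input, nd4j::NDArray* output,
                              nd4j::NDArray* gradO, nd4j::NDArray* gradI) {
        const int kD = 2, kH = 2, kW = 2, sD = 2, sH = 2, sW = 2, pD = 0, pH = 0, pW = 0;
        const int isNCDHW = 1;
        nd4j::mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max);
        nd4j::mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max);
    }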

@@ -2031,121 +2031,6 @@ TEST_F(DeclarableOpsTests1, Sum1) {
 }
 */

-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Avgpool2d_test1) {
-
-    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
-    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
-    // auto z('c',{bS,iD,oH,oW});
-
-    auto variableSpace = new VariableSpace();
-    variableSpace->putVariable(-1, x);
-    // variableSpace->putVariable(1, &z);
-
-    auto block = new Context(1, variableSpace, false);
-    block->fillInputs({-1});
-    std::vector<int>* argI = block->getIArguments();
-    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-
-    nd4j::ops::avgpool2d pooling;
-    Nd4jStatus status = pooling.execute(block);
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-
-    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
-    ASSERT_TRUE(exp.isSameShape(result));
-
-    delete variableSpace;
-    delete block;
-}
-
-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Avgpool2d_test2) {
-    const int bS = 2;
-    const int iD = 1;
-    const int iH = 28;
-    const int iW = 28;
-    const int kH = 5;
-    const int kW = 5;
-    const int sH = 1;
-    const int sW = 1;
-    const int pH = 0;
-    const int pW = 0;
-    const int dH = 1;
-    const int dW = 1;
-    const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height
-    const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width
-
-    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
-    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
-    // auto z('c',{bS,iD,oH,oW});
-
-    auto variableSpace = new VariableSpace();
-    variableSpace->putVariable(-1, x);
-    // variableSpace->putVariable(1, &z);
-
-    auto block = new Context(1, variableSpace, false);
-    block->fillInputs({-1});
-    std::vector<int>* argI = block->getIArguments();
-    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-
-    nd4j::ops::avgpool2d pooling;
-    Nd4jStatus status = pooling.execute(block);
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-
-    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
-    // result->printShapeInfo();
-    ASSERT_TRUE(exp.isSameShape(result));
-
-    delete variableSpace;
-    delete block;
-}
-
-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Avgpool2d_test3) {
-    const int bS = 2;
-    const int iD = 1;
-    const int iH = 28;
-    const int iW = 28;
-    const int kH = 5;
-    const int kW = 5;
-    const int sH = 1;
-    const int sW = 1;
-    const int pH = 0;
-    const int pW = 0;
-    const int dH = 1;
-    const int dW = 1;
-    const int oH = (int) nd4j::math::nd4j_ceil<float, int>(iH * 1.f / sH);
-    const int oW = (int) nd4j::math::nd4j_ceil<float, int>(iW * 1.f / sW);
-
-    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
-    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
-    // auto z('c',{bS,iD,oH,oW});
-
-    auto variableSpace = new VariableSpace();
-    variableSpace->putVariable(-1, x);
-    // variableSpace->putVariable(1, &z);
-
-    auto block = new Context(1, variableSpace, false);
-    block->fillInputs({-1});
-    std::vector<int>* argI = block->getIArguments();
-    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-
-    nd4j::ops::avgpool2d pooling;
-    Nd4jStatus status = pooling.execute(block);
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-
-    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
-    // result->printShapeInfo();
-    ASSERT_TRUE(exp.isSameShape(result));
-
-    delete variableSpace;
-    delete block;
-}

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests1, Pnormpool2d1) {

@@ -360,7 +360,6 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) {
                                         917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5,
                                        1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5});
     input.linspace(1.);
-    input.syncToDevice();

     nd4j::ops::avgpool2d op;
     auto results = op.evaluate({&input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat});
@@ -377,6 +376,160 @@
     delete results;
 }

+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, avgpool2d_13) {
+
+    const int bS = 2;  // batch size
+    const int iD = 1;  // input depth (number of picture channels, for example rgb=3)
+    const int iH = 28; // picture height in pixels
+    const int iW = 28; // picture width in pixels
+    const int kH = 5;  // kernel height in pixels
+    const int kW = 5;  // kernel width in pixels
+    const int sH = 1;  // stride step in horizontal direction
+    const int sW = 1;  // stride step in vertical direction
+    const int pH = 0;  // padding height
+    const int pW = 0;  // padding width
+    const int dH = 2;  // dilation height
+    const int dW = 2;  // dilation width
+    const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height
+    const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width
+
+    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
+    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
+    // auto z('c',{bS,iD,oH,oW});
+
+    auto variableSpace = new VariableSpace();
+    variableSpace->putVariable(-1, x);
+    // variableSpace->putVariable(1, &z);
+
+    auto block = new Context(1, variableSpace, false);
+    block->fillInputs({-1});
+    std::vector<int>* argI = block->getIArguments();
+    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+
+    nd4j::ops::avgpool2d pooling;
+    Nd4jStatus status = pooling.execute(block);
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+
+    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
+    ASSERT_TRUE(exp.isSameShape(result));
+
+    delete variableSpace;
+    delete block;
+}

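The oH/oW expressions in avgpool2d_13 are the VALID-mode size formula with dilation folded in: the effective kernel extent is kH + (kH-1)*(dH-1). A standalone check with the test's numbers:

    #include <cstdio>

    int main() {
        const int iH = 28, kH = 5, sH = 1, pH = 0, dH = 2;
        const int effKH = kH + (kH - 1) * (dH - 1);                       // 9
        const int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // same formula as the test
        std::printf("effective kernel = %d, oH = %d\n", effKH, oH);       // prints 9 and 20
        return 0;
    }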
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests4, avgpool2d_14) {
|
||||
const int bS = 2;
|
||||
const int iD = 1;
|
||||
const int iH = 28;
|
||||
const int iW = 28;
|
||||
const int kH = 5;
|
||||
const int kW = 5;
|
||||
const int sH = 1;
|
||||
const int sW = 1;
|
||||
const int pH = 0;
|
||||
const int pW = 0;
|
||||
const int dH = 1;
|
||||
const int dW = 1;
|
||||
const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height
|
||||
const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width
|
||||
|
||||
|
||||
auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
|
||||
auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
|
||||
// auto z('c',{bS,iD,oH,oW});
|
||||
|
||||
auto variableSpace = new VariableSpace();
|
||||
variableSpace->putVariable(-1, x);
|
||||
// variableSpace->putVariable(1, &z);
|
||||
|
||||
auto block = new Context(1, variableSpace, false);
|
||||
block->fillInputs({-1});
|
||||
std::vector<int>* argI = block->getIArguments();
|
||||
*argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
|
||||
nd4j::ops::avgpool2d pooling;
|
||||
Nd4jStatus status = pooling.execute(block);
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
|
||||
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
|
||||
// result->printShapeInfo();
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
|
||||
delete variableSpace;
|
||||
delete block;
|
||||
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, Avgpool2d_test15) {
    const int bS = 2;
    const int iD = 1;
    const int iH = 28;
    const int iW = 28;
    const int kH = 5;
    const int kW = 5;
    const int sH = 1;
    const int sW = 1;
    const int pH = 0;
    const int pW = 0;
    const int dH = 1;
    const int dW = 1;
    const int oH = (int) nd4j::math::nd4j_ceil<float, int>(iH * 1.f / sH);   // output height for SAME mode
    const int oW = (int) nd4j::math::nd4j_ceil<float, int>(iW * 1.f / sW);   // output width for SAME mode

    auto x   = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
    auto exp = NDArrayFactory::create<float> ('c', {bS,iD,oH,oW});
    // auto z('c', {bS,iD,oH,oW});

    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, x);
    // variableSpace->putVariable(1, &z);

    auto block = new Context(1, variableSpace, false);
    block->fillInputs({-1});
    std::vector<int>* argI = block->getIArguments();
    *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1, 0, 0};             // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0; 10 - data format

    nd4j::ops::avgpool2d pooling;
    Nd4jStatus status = pooling.execute(block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
    // result->printShapeInfo();
    ASSERT_TRUE(exp.isSameShape(result));

    delete variableSpace;
    delete block;
}
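
// In SAME mode (IArg 8 == 1) padding is chosen so that only the stride
// shrinks the spatial dims, which is what the nd4j_ceil expressions above
// encode. An integer-only sketch (samePoolOutSize is a hypothetical helper):
static int samePoolOutSize(int inSize, int s) {
    return (inSize + s - 1) / s;                               // == ceil(inSize / s) for positive ints; 28/1 -> 28
}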

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, avgpool2d_16) {

    int bS=2, iH=4,iW=4, iC=2, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int paddingMode = 1;              // 1-SAME, 0-VALID
    int dataFormat  = 1;              // 1-NHWC, 0-NCHW

    NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
    NDArray output('f', {bS, oH, oW, iC}, nd4j::DataType::FLOAT32);
    NDArray expected('c', {bS, oH, oW, iC}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, nd4j::DataType::FLOAT32);

    input.linspace(1.);

    nd4j::ops::avgpool2d op;
    auto status = op.execute({&input}, {&output}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}, {});

    ASSERT_EQ(Status::OK(), status);

    // output.printBuffer();
    // expected.printIndexedBuffer("expected");

    ASSERT_TRUE(expected.equalsTo(output));
}
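
// Sanity derivation for the first expected entry (illustrative, not part of
// the test): with linspace(1.) the batch-0, channel-0, top-left 2x2 window in
// 'c'-ordered NHWC holds x(0,0,0,0)=1, x(0,0,1,0)=3, x(0,1,0,0)=9,
// x(0,1,1,0)=11, whose average is the 6.f above. The 'f'-ordered output
// presumably exercises the order-f handling this commit touches.
static_assert((1 + 3 + 9 + 11) / 4 == 6, "first expected avgpool2d_16 value is 6.f");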

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, biasadd_1) {
    auto x = NDArrayFactory::create<double>('c', {2, 3, 3, 2});

@ -422,50 +422,38 @@ TEST_F(PlaygroundTests, my) {
     delete variableSpace;
 }


 #include<ops/declarable/helpers/batchnorm.h>

 TEST_F(PlaygroundTests, my) {

-    const int N = 10000;
-    const Nd4jLong dim0(128), dim1(128), dim2(128);
+    int N = 100;
+    int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
+    int oH=128,oW=128;

-    NDArray input('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE);
-    NDArray mean('c', {dim1}, nd4j::DataType::DOUBLE);
-    NDArray variance('c', {dim1}, nd4j::DataType::DOUBLE);
-    NDArray gamma('c', {dim1}, nd4j::DataType::DOUBLE);
-    NDArray beta ('c', {dim1}, nd4j::DataType::DOUBLE);
+    int paddingMode = 1;              // 1-SAME, 0-VALID
+    int dataFormat  = 1;              // 1-NHWC, 0-NCHW

-    NDArray output('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE);
+    // NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32);
+    // NDArray output('c', {bS, oC, oH, oW}, nd4j::DataType::FLOAT32);
+    NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
+    NDArray output('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32);
+    // NDArray weights('c', {kH, kW, iC, oC}, nd4j::DataType::FLOAT32);    // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+    NDArray weights('c', {oC, iC, kH, kW}, nd4j::DataType::FLOAT32);
+    NDArray bias('c', {oC}, nd4j::DataType::FLOAT32);

-    input.linspace(-100, 0.1);
-    mean.linspace(-50, 0.15);
-    variance.linspace(-5, 0.2);
-    gamma = 1.5;
-    beta = -2.5;
+    input = 5.;
+    weights = 3.;
+    bias = 1.;

     // warm up
-    ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5);
+    nd4j::ops::conv2d op;
+    auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

     auto timeStart = std::chrono::system_clock::now();
     for (int i = 0; i < N; ++i)
-        ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5);
+        err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
     auto timeEnd = std::chrono::system_clock::now();
-    auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart)/N).count();
-
-    printf("time: %li \n", time);
+    auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();
+
+    printf("time: %lld \n", (long long) time);    // count() is 64-bit, so a plain %i specifier would be wrong here
 }

*/
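
// A steadier variant of the timing loop above (a sketch only, shown as a
// comment because the benchmark itself is commented out; it assumes the same
// op/input/weights/bias/N names). std::chrono::steady_clock is monotonic,
// unlike system_clock, so it is the usual choice for micro-benchmarks:
//
//     auto t0 = std::chrono::steady_clock::now();
//     for (int i = 0; i < N; ++i)
//         err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
//     auto t1 = std::chrono::steady_clock::now();
//     auto avgUs = std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count() / N;
//     printf("avg time: %lld us\n", (long long) avgUs);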