- make agreement between our and mkl api dilation/padding formulas (#47)

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2019-11-14 19:21:22 +02:00 committed by raver119
parent c5b912bddf
commit 62d8e0d409
6 changed files with 168 additions and 22 deletions

View File

@ -194,6 +194,54 @@ namespace nd4j {
} }
static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) {
if(kH != 1) {
if(isSameMode) {
pH = (oH - 1) * sH - iH + kH - pH;
dH = dH - 1;
}
else
dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
}
if(kW != 1) {
if(isSameMode) {
pW = (oW - 1) * sW - iW + kW - pW;
dW = dW - 1;
}
else
dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
}
}
static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
if(kD != 1) {
if(isSameMode) {
pD = (oD - 1) * sD - iD + kD - pD;
dD = dD - 1;
}
else
dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
}
if(kH != 1) {
if(isSameMode) {
pH = (oH - 1) * sH - iH + kH - pH;
dH = dH - 1;
}
else
dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
}
if(kW != 1) {
if(isSameMode) {
pW = (oW - 1) * sW - iW + kW - pW;
dW = dW - 1;
}
else
dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
}
}
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
// static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs); // static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);

View File

@ -46,10 +46,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sH, sW }; mkldnn::memory::dims strides = { sH, sW };
mkldnn::memory::dims dilation = { dH - 1, dW - 1};
mkldnn::memory::dims padding = { pH, pW }; mkldnn::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; mkldnn::memory::dims padding_r = { pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dHmkl, dWmkl };
// input type // input type
mkldnn::memory::data_type xType; mkldnn::memory::data_type xType;
@ -190,11 +193,13 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
mkldnn::memory::dims strides = { sH, sW }; int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
mkldnn::memory::dims dilation = { dH - 1, dW - 1 }; ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
mkldnn::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
mkldnn::memory::dims strides = { sH, sW };
mkldnn::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dHmkl, dWmkl };
// input type // input type
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// weights type // weights type
@ -425,7 +430,6 @@ PLATFORM_CHECK(deconv2d) {
return block.isUseMKLDNN() && ( return block.isUseMKLDNN() && (
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF ) ||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
); );
} }

View File

@ -47,10 +47,13 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sD, sH, sW }; mkldnn::memory::dims strides = { sD, sH, sW };
mkldnn::memory::dims dilation = { dD - 1, dH - 1, dW - 1};
mkldnn::memory::dims padding = { pD, pH, pW }; mkldnn::memory::dims padding = { pD, pH, pW };
mkldnn::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
// input type // input type
mkldnn::memory::data_type xType; mkldnn::memory::data_type xType;
@ -194,10 +197,13 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
mkldnn::memory::dims strides = { sD, sH, sW }; mkldnn::memory::dims strides = { sD, sH, sW };
mkldnn::memory::dims dilation = { dD - 1, dH - 1, dW - 1 };
mkldnn::memory::dims padding = { pD, pH, pW }; mkldnn::memory::dims padding = { pD, pH, pW };
mkldnn::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
// input type // input type
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
@ -438,7 +444,6 @@ PLATFORM_CHECK(deconv3d) {
return block.isUseMKLDNN() && ( return block.isUseMKLDNN() && (
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF ) ||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
); );
} }

View File

@ -20,6 +20,7 @@
#include <mkldnn_types.h> #include <mkldnn_types.h>
#include "mkldnnUtils.h" #include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn; using namespace mkldnn;
@ -154,6 +155,14 @@ namespace nd4j {
mkldnn::memory::dims conv_bias_tz = { oC }; mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW }; mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
conv_strides = { sH, sW };
conv_padding = { pH, pW };
conv_padding_r = { pHmkl, pWmkl };
conv_dilation = { dHmkl, dWmkl };
conv_strides = { sH, sW }; conv_strides = { sH, sW };
conv_padding = { pH, pW }; conv_padding = { pH, pW };
conv_dilation = { dH-1, dW-1}; conv_dilation = { dH-1, dW-1};
@ -234,12 +243,13 @@ namespace nd4j {
mkldnn::memory::dims conv_bias_tz = { oC }; mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
conv_strides = { sD, sH, sW }; conv_strides = { sD, sH, sW };
conv_dilation = { dD-1, dH-1, dW-1};
conv_padding = { pD, pH, pW }; conv_padding = { pD, pH, pW };
conv_padding_r = { (oD - 1) * sD - iD + kD - pD, conv_padding_r = { pDmkl, pHmkl, pWmkl };
(oH - 1) * sH - iH + kH - pH, conv_dilation = { dDmkl, dHmkl, dWmkl };
(oW - 1) * sW - iW + kW - pW };
auto type = mkldnn::memory::data_type::f32; auto type = mkldnn::memory::data_type::f32;
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc; auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;

View File

@ -2137,9 +2137,9 @@ TEST_F(ConvolutionTests1, deconv2d_test1) {
int paddingMode = 0; // 1-SAME, 0-VALID; int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NHWC, 0-NCHW int dataFormat = 1; // 1-NHWC, 0-NCHW
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC}); auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
auto weights = NDArrayFactory::create<double>('c', {kH, kW, oC, iC}); auto weights = NDArrayFactory::create<float>('c', {kH, kW, oC, iC});
auto exp = NDArrayFactory::create<double>('c', {bS, oH, oW, oC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75, auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75, 52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75,
@ -2170,9 +2170,9 @@ TEST_F(ConvolutionTests1, deconv2d_test2) {
int paddingMode = 1; // 1-SAME, 0-VALID; int paddingMode = 1; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NHWC, 0-NCHW int dataFormat = 1; // 1-NHWC, 0-NCHW
auto input = NDArrayFactory::create<double>('c', {bS, oH, oW, oC}); auto input = NDArrayFactory::create<float>('c', {bS, oH, oW, oC});
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, oC}); auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, oC});
auto exp = NDArrayFactory::create<double>('c', {bS, iH, iW, iC}, {2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , auto exp = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}, {2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 ,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
@ -2194,6 +2194,39 @@ TEST_F(ConvolutionTests1, deconv2d_test2) {
delete results; delete results;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, deconv2d_test3) {
int bS=1, oH=5,oW=5, oC=3,iC=2, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2;
int iH=3,iW=3;
int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NHWC, 0-NCHW
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
auto weights = NDArrayFactory::create<float>('c', {kH, kW, oC, iC});
auto bias = NDArrayFactory::create<float>('c', {oC});
auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, {-2.9, -6.8, -10.7, -2.6, -6.1, -9.6, -16.9, -23.9, -30.9, -13.1, -16.6, -20.1, -11.6, -14.7, -17.8, -2.0, -4.7, -7.4, -1.7, -4.0, -6.3, -11.5, -16.1,
-20.7, -8.6, -10.9, -13.2, -7.1, -9.0, -10.9, -27.4, -32.8, -38.2, -24.4, -29.0, -33.6, -65.0, -74.2, -83.4, -38.2, -42.8, -47.4,
-32.8, -36.6, -40.4, -18.2, -20.9, -23.6, -15.5, -17.8, -20.1, -39.1, -43.7, -48.3, -22.4, -24.7, -27.0, -18.5, -20.4, -22.3, -10.1, -11.6, -13.1,
-7.4, -8.5, -9.6, -19.3, -21.5, -23.7, -10.7, -11.8, -12.9, -6.8, -7.5, -8.2});
input.linspace(-10, 0.5);
weights.linspace(0.1, 0.1);
bias = 0.2;
nd4j::ops::deconv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
ASSERT_EQ(Status::OK(), results->status());
auto output = results->at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete results;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) {

View File

@ -567,6 +567,52 @@ TEST_F(ConvolutionTests2, deconv3d_test4) {
delete results; delete results;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, deconv3d_test5) {
int bS=1, oD=5,oH=5,oW=5, oC=3,iC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
int iD=3,iH=3,iW=3;
int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NHWC, 0-NCHW
auto input = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, oC, iC});
auto bias = NDArrayFactory::create<float>('c', {oC});
auto exp = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC}, {-2.9, -6.8, -10.7, -2.6, -6.1, -9.6, -16.9, -23.9, -30.9, -13.1, -16.6, -20.1, -11.6, -14.7, -17.8, -2.0, -4.7, -7.4, -1.7, -4.0, -6.3, -11.5,
-16.1, -20.7, -8.6, -10.9, -13.2, -7.1, -9.0, -10.9, -27.4, -32.8, -38.2, -24.4, -29.0, -33.6, -65.0, -74.2, -83.4, -38.2, -42.8, -47.4, -32.8,
-36.6, -40.4, -18.2, -20.9, -23.6, -15.5, -17.8, -20.1, -39.1, -43.7, -48.3, -22.4, -24.7, -27.0, -18.5, -20.4, -22.3, -10.1, -11.6, -13.1, -7.4,
-8.5, -9.6, -19.3, -21.5, -23.7, -10.7, -11.8, -12.9, -6.8, -7.5, -8.2, -0.2, -0.5, -0.8, 0.1, 0.2, 0.3, -0.7, -0.5, -0.3, 0.4, 0.5, 0.6, 1.9, 2.4,
2.9, 0.7, 1.6, 2.5, 1.0, 2.3, 3.6, 4.7, 7.3, 9.9, 4.9, 6.2, 7.5, 6.4, 8.1, 9.8, -0.4, 1.4, 3.2, 2.6, 5.2, 7.8, 10.6, 15.8, 21.0, 10.4, 13.0, 15.6,
15.8, 19.2, 22.6, 6.1, 7.0, 7.9, 8.8, 10.1, 11.4, 20.3, 22.9, 25.5, 12.7, 14.0, 15.3, 16.6, 18.3, 20.0, 14.2, 16.3, 18.4, 16.9, 19.4, 21.9, 40.1,
45.1, 50.1, 24.4, 26.9, 29.4, 28.3, 31.2, 34.1, -47.2, -47.8, -48.4, -41.8, -41.6, -41.4, -85.4, -85., -84.6, -41.2, -41.0, -40.8, -33.4, -32.4, -31.4,
-31., -29.2, -27.4, -25.6, -23.0, -20.4, -45.8, -40.6, -35.4, -17.8, -15.2, -12.6, -10.0, -6.6, -3.2, -65.6, -62.0, -58.4, -50.0, -44.8, -39.6, -89.2,
-78.8, -68.4, -34.4, -29.2, -24., -14.0, -7.2, -0.4, -20.2, -18.4, -16.6, -10., -7.4, -4.8, -14.6, -9.4, -4.2, -2.2, 0.4, 3.0, 10.4, 13.8, 17.2, 10.4,
14.6, 18.8, 20.6, 25.6, 30.6, 53.8, 63.8, 73.8, 35.6, 40.6, 45.6, 48.2, 54.0, 59.8, -3.8, -4.1, -4.4, 1.3, 1.4, 1.5, 1.7, 1.9, 2.1, 1.6, 1.7, 1.8, 7.9,
8.4, 8.9, 11.5, 12.4, 13.3, 16.6, 17.9, 19.2, 35.9, 38.5, 41.1, 20.5, 21.8, 23.1, 26.8, 28.5, 30.2, 21.2, 23.0, 24.8, 33.8, 36.4, 39.0, 73.0, 78.2,
83.4, 41.6, 44.2, 46.8, 56.6, 60.0, 63.4, 16.9, 17.8, 18.7, 24.4, 25.7, 27., 51.5, 54.1, 56.7, 28.3, 29.6, 30.9, 37.0, 38.7, 40.4, 39.4, 41.5,
43.6, 46.9, 49.4, 51.9, 100.1, 105.1, 110.1, 54.4, 56.9, 59.4, 63.1, 66.0, 68.9, 42.1, 45.4, 48.7, 47.2, 50.9, 54.6, 104.3, 111.7,
119.1, 58.3, 62.0, 65.7, 64.6, 68.7, 72.8, 57.4, 61.9, 66.4, 62.5, 67.4, 72.3, 138.5, 148.3, 158.1, 77.2, 82.1, 87.0, 83.5, 88.8, 94.1,
134.6, 143.6, 152.6, 147.2, 157.0, 166.8, 321.4, 341.0, 360.6, 176.6, 186.4, 196.2, 191.6, 202.2, 212.8, 84.4, 88.9,
93.4, 91.9, 96.8, 101.7, 197.3, 207.1, 216.9, 106.6, 111.5, 116.4, 115.3, 120.6, 125.9, 106.9, 112.6, 118.3, 114.4, 120.5, 126.6, 245.9, 258.1, 270.3, 132.7, 138.8, 144.9, 141.4, 147.9, 154.4});
input.linspace(-10, 0.5);
weights.linspace(0.1, 0.1);
bias = 0.2;
nd4j::ops::deconv3d op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
ASSERT_EQ(Status::OK(), results->status());
auto output = results->at(0);
// output->printBuffer();
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete results;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, deconv3d_bp_test1) { TEST_F(ConvolutionTests2, deconv3d_bp_test1) {