- make agreement between our and mkl api dilation/padding formulas (#47)
Signed-off-by: Yurii <iuriish@yahoo.com>master
parent
c5b912bddf
commit
62d8e0d409
|
@ -194,6 +194,54 @@ namespace nd4j {
|
|||
|
||||
}
|
||||
|
||||
static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) {
|
||||
|
||||
if(kH != 1) {
|
||||
if(isSameMode) {
|
||||
pH = (oH - 1) * sH - iH + kH - pH;
|
||||
dH = dH - 1;
|
||||
}
|
||||
else
|
||||
dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
||||
}
|
||||
if(kW != 1) {
|
||||
if(isSameMode) {
|
||||
pW = (oW - 1) * sW - iW + kW - pW;
|
||||
dW = dW - 1;
|
||||
}
|
||||
else
|
||||
dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
|
||||
|
||||
if(kD != 1) {
|
||||
if(isSameMode) {
|
||||
pD = (oD - 1) * sD - iD + kD - pD;
|
||||
dD = dD - 1;
|
||||
}
|
||||
else
|
||||
dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
|
||||
}
|
||||
if(kH != 1) {
|
||||
if(isSameMode) {
|
||||
pH = (oH - 1) * sH - iH + kH - pH;
|
||||
dH = dH - 1;
|
||||
}
|
||||
else
|
||||
dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
||||
}
|
||||
if(kW != 1) {
|
||||
if(isSameMode) {
|
||||
pW = (oW - 1) * sW - iW + kW - pW;
|
||||
dW = dW - 1;
|
||||
}
|
||||
else
|
||||
dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
|
||||
}
|
||||
}
|
||||
|
||||
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
|
||||
// static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
|
||||
|
|
|
@ -46,10 +46,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||
|
||||
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
|
||||
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
|
||||
|
||||
mkldnn::memory::dims strides = { sH, sW };
|
||||
mkldnn::memory::dims dilation = { dH - 1, dW - 1};
|
||||
mkldnn::memory::dims padding = { pH, pW };
|
||||
mkldnn::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||
mkldnn::memory::dims padding_r = { pHmkl, pWmkl };
|
||||
mkldnn::memory::dims dilation = { dHmkl, dWmkl };
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type xType;
|
||||
|
@ -190,11 +193,13 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||
|
||||
mkldnn::memory::dims strides = { sH, sW };
|
||||
mkldnn::memory::dims dilation = { dH - 1, dW - 1 };
|
||||
mkldnn::memory::dims padding = { pH, pW };
|
||||
mkldnn::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
|
||||
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
|
||||
|
||||
mkldnn::memory::dims strides = { sH, sW };
|
||||
mkldnn::memory::dims padding = { pH, pW };
|
||||
mkldnn::memory::dims padding_r = { pHmkl, pWmkl };
|
||||
mkldnn::memory::dims dilation = { dHmkl, dWmkl };
|
||||
// input type
|
||||
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
// weights type
|
||||
|
@ -425,7 +430,6 @@ PLATFORM_CHECK(deconv2d) {
|
|||
|
||||
return block.isUseMKLDNN() && (
|
||||
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF ) ||
|
||||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -47,10 +47,13 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||
|
||||
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
|
||||
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
|
||||
|
||||
mkldnn::memory::dims strides = { sD, sH, sW };
|
||||
mkldnn::memory::dims dilation = { dD - 1, dH - 1, dW - 1};
|
||||
mkldnn::memory::dims padding = { pD, pH, pW };
|
||||
mkldnn::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||
mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
|
||||
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type xType;
|
||||
|
@ -194,10 +197,13 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||
|
||||
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
|
||||
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
|
||||
|
||||
mkldnn::memory::dims strides = { sD, sH, sW };
|
||||
mkldnn::memory::dims dilation = { dD - 1, dH - 1, dW - 1 };
|
||||
mkldnn::memory::dims padding = { pD, pH, pW };
|
||||
mkldnn::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||
mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
|
||||
mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||
|
@ -438,7 +444,6 @@ PLATFORM_CHECK(deconv3d) {
|
|||
|
||||
return block.isUseMKLDNN() && (
|
||||
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF ) ||
|
||||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include <mkldnn_types.h>
|
||||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
|
||||
|
@ -154,6 +155,14 @@ namespace nd4j {
|
|||
mkldnn::memory::dims conv_bias_tz = { oC };
|
||||
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
||||
|
||||
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
|
||||
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
|
||||
|
||||
conv_strides = { sH, sW };
|
||||
conv_padding = { pH, pW };
|
||||
conv_padding_r = { pHmkl, pWmkl };
|
||||
conv_dilation = { dHmkl, dWmkl };
|
||||
|
||||
conv_strides = { sH, sW };
|
||||
conv_padding = { pH, pW };
|
||||
conv_dilation = { dH-1, dW-1};
|
||||
|
@ -234,12 +243,13 @@ namespace nd4j {
|
|||
mkldnn::memory::dims conv_bias_tz = { oC };
|
||||
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
||||
|
||||
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
|
||||
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
|
||||
|
||||
conv_strides = { sD, sH, sW };
|
||||
conv_dilation = { dD-1, dH-1, dW-1};
|
||||
conv_padding = { pD, pH, pW };
|
||||
conv_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
||||
(oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
conv_padding_r = { pDmkl, pHmkl, pWmkl };
|
||||
conv_dilation = { dDmkl, dHmkl, dWmkl };
|
||||
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
|
||||
|
|
|
@ -2137,9 +2137,9 @@ TEST_F(ConvolutionTests1, deconv2d_test1) {
|
|||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||
|
||||
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
|
||||
auto weights = NDArrayFactory::create<double>('c', {kH, kW, oC, iC});
|
||||
auto exp = NDArrayFactory::create<double>('c', {bS, oH, oW, oC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
|
||||
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
|
||||
auto weights = NDArrayFactory::create<float>('c', {kH, kW, oC, iC});
|
||||
auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
|
||||
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
|
||||
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
|
||||
52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75,
|
||||
|
@ -2170,9 +2170,9 @@ TEST_F(ConvolutionTests1, deconv2d_test2) {
|
|||
int paddingMode = 1; // 1-SAME, 0-VALID;
|
||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||
|
||||
auto input = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
|
||||
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, oC});
|
||||
auto exp = NDArrayFactory::create<double>('c', {bS, iH, iW, iC}, {2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 ,
|
||||
auto input = NDArrayFactory::create<float>('c', {bS, oH, oW, oC});
|
||||
auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, oC});
|
||||
auto exp = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}, {2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 ,
|
||||
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
|
||||
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
|
||||
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
|
||||
|
@ -2194,6 +2194,39 @@ TEST_F(ConvolutionTests1, deconv2d_test2) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ConvolutionTests1, deconv2d_test3) {
|
||||
|
||||
int bS=1, oH=5,oW=5, oC=3,iC=2, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2;
|
||||
int iH=3,iW=3;
|
||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||
|
||||
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
|
||||
auto weights = NDArrayFactory::create<float>('c', {kH, kW, oC, iC});
|
||||
auto bias = NDArrayFactory::create<float>('c', {oC});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, {-2.9, -6.8, -10.7, -2.6, -6.1, -9.6, -16.9, -23.9, -30.9, -13.1, -16.6, -20.1, -11.6, -14.7, -17.8, -2.0, -4.7, -7.4, -1.7, -4.0, -6.3, -11.5, -16.1,
|
||||
-20.7, -8.6, -10.9, -13.2, -7.1, -9.0, -10.9, -27.4, -32.8, -38.2, -24.4, -29.0, -33.6, -65.0, -74.2, -83.4, -38.2, -42.8, -47.4,
|
||||
-32.8, -36.6, -40.4, -18.2, -20.9, -23.6, -15.5, -17.8, -20.1, -39.1, -43.7, -48.3, -22.4, -24.7, -27.0, -18.5, -20.4, -22.3, -10.1, -11.6, -13.1,
|
||||
-7.4, -8.5, -9.6, -19.3, -21.5, -23.7, -10.7, -11.8, -12.9, -6.8, -7.5, -8.2});
|
||||
|
||||
input.linspace(-10, 0.5);
|
||||
weights.linspace(0.1, 0.1);
|
||||
bias = 0.2;
|
||||
|
||||
nd4j::ops::deconv2d op;
|
||||
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) {
|
||||
|
||||
|
|
|
@ -567,6 +567,52 @@ TEST_F(ConvolutionTests2, deconv3d_test4) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ConvolutionTests2, deconv3d_test5) {
|
||||
|
||||
int bS=1, oD=5,oH=5,oW=5, oC=3,iC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
|
||||
int iD=3,iH=3,iW=3;
|
||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||
|
||||
auto input = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
|
||||
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, oC, iC});
|
||||
auto bias = NDArrayFactory::create<float>('c', {oC});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC}, {-2.9, -6.8, -10.7, -2.6, -6.1, -9.6, -16.9, -23.9, -30.9, -13.1, -16.6, -20.1, -11.6, -14.7, -17.8, -2.0, -4.7, -7.4, -1.7, -4.0, -6.3, -11.5,
|
||||
-16.1, -20.7, -8.6, -10.9, -13.2, -7.1, -9.0, -10.9, -27.4, -32.8, -38.2, -24.4, -29.0, -33.6, -65.0, -74.2, -83.4, -38.2, -42.8, -47.4, -32.8,
|
||||
-36.6, -40.4, -18.2, -20.9, -23.6, -15.5, -17.8, -20.1, -39.1, -43.7, -48.3, -22.4, -24.7, -27.0, -18.5, -20.4, -22.3, -10.1, -11.6, -13.1, -7.4,
|
||||
-8.5, -9.6, -19.3, -21.5, -23.7, -10.7, -11.8, -12.9, -6.8, -7.5, -8.2, -0.2, -0.5, -0.8, 0.1, 0.2, 0.3, -0.7, -0.5, -0.3, 0.4, 0.5, 0.6, 1.9, 2.4,
|
||||
2.9, 0.7, 1.6, 2.5, 1.0, 2.3, 3.6, 4.7, 7.3, 9.9, 4.9, 6.2, 7.5, 6.4, 8.1, 9.8, -0.4, 1.4, 3.2, 2.6, 5.2, 7.8, 10.6, 15.8, 21.0, 10.4, 13.0, 15.6,
|
||||
15.8, 19.2, 22.6, 6.1, 7.0, 7.9, 8.8, 10.1, 11.4, 20.3, 22.9, 25.5, 12.7, 14.0, 15.3, 16.6, 18.3, 20.0, 14.2, 16.3, 18.4, 16.9, 19.4, 21.9, 40.1,
|
||||
45.1, 50.1, 24.4, 26.9, 29.4, 28.3, 31.2, 34.1, -47.2, -47.8, -48.4, -41.8, -41.6, -41.4, -85.4, -85., -84.6, -41.2, -41.0, -40.8, -33.4, -32.4, -31.4,
|
||||
-31., -29.2, -27.4, -25.6, -23.0, -20.4, -45.8, -40.6, -35.4, -17.8, -15.2, -12.6, -10.0, -6.6, -3.2, -65.6, -62.0, -58.4, -50.0, -44.8, -39.6, -89.2,
|
||||
-78.8, -68.4, -34.4, -29.2, -24., -14.0, -7.2, -0.4, -20.2, -18.4, -16.6, -10., -7.4, -4.8, -14.6, -9.4, -4.2, -2.2, 0.4, 3.0, 10.4, 13.8, 17.2, 10.4,
|
||||
14.6, 18.8, 20.6, 25.6, 30.6, 53.8, 63.8, 73.8, 35.6, 40.6, 45.6, 48.2, 54.0, 59.8, -3.8, -4.1, -4.4, 1.3, 1.4, 1.5, 1.7, 1.9, 2.1, 1.6, 1.7, 1.8, 7.9,
|
||||
8.4, 8.9, 11.5, 12.4, 13.3, 16.6, 17.9, 19.2, 35.9, 38.5, 41.1, 20.5, 21.8, 23.1, 26.8, 28.5, 30.2, 21.2, 23.0, 24.8, 33.8, 36.4, 39.0, 73.0, 78.2,
|
||||
83.4, 41.6, 44.2, 46.8, 56.6, 60.0, 63.4, 16.9, 17.8, 18.7, 24.4, 25.7, 27., 51.5, 54.1, 56.7, 28.3, 29.6, 30.9, 37.0, 38.7, 40.4, 39.4, 41.5,
|
||||
43.6, 46.9, 49.4, 51.9, 100.1, 105.1, 110.1, 54.4, 56.9, 59.4, 63.1, 66.0, 68.9, 42.1, 45.4, 48.7, 47.2, 50.9, 54.6, 104.3, 111.7,
|
||||
119.1, 58.3, 62.0, 65.7, 64.6, 68.7, 72.8, 57.4, 61.9, 66.4, 62.5, 67.4, 72.3, 138.5, 148.3, 158.1, 77.2, 82.1, 87.0, 83.5, 88.8, 94.1,
|
||||
134.6, 143.6, 152.6, 147.2, 157.0, 166.8, 321.4, 341.0, 360.6, 176.6, 186.4, 196.2, 191.6, 202.2, 212.8, 84.4, 88.9,
|
||||
93.4, 91.9, 96.8, 101.7, 197.3, 207.1, 216.9, 106.6, 111.5, 116.4, 115.3, 120.6, 125.9, 106.9, 112.6, 118.3, 114.4, 120.5, 126.6, 245.9, 258.1, 270.3, 132.7, 138.8, 144.9, 141.4, 147.9, 154.4});
|
||||
|
||||
input.linspace(-10, 0.5);
|
||||
weights.linspace(0.1, 0.1);
|
||||
bias = 0.2;
|
||||
|
||||
nd4j::ops::deconv3d op;
|
||||
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
|
||||
auto output = results->at(0);
|
||||
// output->printBuffer();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ConvolutionTests2, deconv3d_bp_test1) {
|
||||
|
||||
|
|
Loading…
Reference in New Issue