Shyrma depthconv (#156)
* - implementation of depthwise_conv2d (both ff/bp) based on mkl dnn api * - minor corrections in deconv3d Signed-off-by: Yurii <iuriish@yahoo.com> * - remove unnecessary time test Signed-off-by: Yurii <iuriish@yahoo.com> * - update mkl dnn version in cmake Signed-off-by: Yurii <iuriish@yahoo.com> * - take into account several notes given by pr reviewer Signed-off-by: Yurii <iuriish@yahoo.com> * - fix bug in depthwise conv2d op based on mkl Signed-off-by: Yurii <iuriish@yahoo.com>master
parent
bbf88b53dd
commit
cae5ef4180
|
@ -5,7 +5,7 @@ project(mkldnn-download NONE)
|
||||||
include(ExternalProject)
|
include(ExternalProject)
|
||||||
ExternalProject_Add(mkldnn
|
ExternalProject_Add(mkldnn
|
||||||
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
||||||
GIT_TAG v1.1.1
|
GIT_TAG v1.1.2
|
||||||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
||||||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
||||||
CONFIGURE_COMMAND ""
|
CONFIGURE_COMMAND ""
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// created by Yurii Shyrma on 08.03.2018
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
|
@ -56,8 +56,8 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
|
||||||
mC = weights->sizeAt(indWmC); // channels multiplier
|
mC = weights->sizeAt(indWmC); // channels multiplier
|
||||||
|
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC);
|
REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC);
|
||||||
if (bias)
|
if (bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
@ -79,8 +79,8 @@ DECLARE_SHAPE_FN(depthwise_conv2d) {
|
||||||
Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC
|
Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC
|
||||||
|
|
||||||
const int rank = 4;
|
const int rank = 4;
|
||||||
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
|
||||||
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
|
||||||
|
|
||||||
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
||||||
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
||||||
|
@ -101,17 +101,18 @@ DECLARE_SHAPE_FN(depthwise_conv2d) {
|
||||||
indIOioC = 1; indIiH = 2;
|
indIOioC = 1; indIiH = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int bS = inputShapeInfo[1]; // batch size
|
const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size
|
||||||
const int iH = inputShapeInfo[indIiH+1]; // input height
|
const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height
|
||||||
const int iW = inputShapeInfo[indIiH+2]; // input width
|
const int iW = shape::sizeAt(inputShapeInfo, indIiH+1); // input width
|
||||||
const int iC = inputShapeInfo[indIOioC+1]; // input channels
|
const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels
|
||||||
const int mC = weightsShapeInfo[indWmC+1]; // channels multiplier(oC = iC*mC)
|
const int mC = shape::sizeAt(weightsShapeInfo, indWmC); // channels multiplier(oC = iC*mC)
|
||||||
const int oC = iC*mC; // output channels
|
const int oC = iC*mC; // output channels
|
||||||
|
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
|
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
|
||||||
|
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
||||||
if (biasShapeInfo)
|
if (biasShapeInfo)
|
||||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo));
|
||||||
|
|
||||||
int oH, oW; // output height, width
|
int oH, oW; // output height, width
|
||||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
@ -178,10 +179,10 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
|
||||||
int trueoH, trueoW; // correct output height, width
|
int trueoH, trueoW; // correct output height, width
|
||||||
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
if(bias)
|
if(bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
@ -190,8 +191,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(depthwise_conv2d_bp) {
|
DECLARE_SHAPE_FN(depthwise_conv2d_bp) {
|
||||||
|
|
||||||
Nd4jLong* inputShapeInfo = inputShape->at(0);
|
Nd4jLong* inputShapeInfo = inputShape->at(0);
|
||||||
|
@ -200,9 +200,9 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) {
|
||||||
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);
|
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);
|
||||||
|
|
||||||
const int rank = 4;
|
const int rank = 4;
|
||||||
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo));
|
||||||
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo));
|
||||||
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo));
|
||||||
|
|
||||||
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
||||||
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
||||||
|
@ -223,22 +223,22 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) {
|
||||||
indIOioC = 1; indIiH = 2;
|
indIOioC = 1; indIiH = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int bS = inputShapeInfo[1]; // batch size
|
const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size
|
||||||
const int iH = inputShapeInfo[indIiH+1]; // input height
|
const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height
|
||||||
const int iW = inputShapeInfo[indIiH+2]; // input width
|
const int iW = shape::sizeAt(inputShapeInfo, indIiH+1); // input width
|
||||||
const int iC = inputShapeInfo[indIOioC+1]; // input channels
|
const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels
|
||||||
const int mC = weightsShapeInfo[indWmC+1]; // channels multiplier(oC = iC*mC)
|
const int mC = shape::sizeAt(weightsShapeInfo, indWmC); // channels multiplier(oC = iC*mC)
|
||||||
const int oC = iC*mC; // output channels
|
const int oC = iC*mC; // output channels
|
||||||
|
|
||||||
int trueoH, trueoW; // correct output height, width
|
int trueoH, trueoW; // correct output height, width
|
||||||
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}));
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1});
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
||||||
if(biasShapeInfo)
|
if(biasShapeInfo)
|
||||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo));
|
||||||
|
|
||||||
auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace());
|
auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace());
|
||||||
auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace());
|
auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace());
|
||||||
|
|
|
@ -34,13 +34,13 @@ namespace platforms {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
|
static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
|
||||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||||
const int isSameMode) {
|
const int paddingMode) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] nchw, mkl doesn't support format nhwc
|
// input [bS, iC, iH, iW] nchw, mkl doesn't support format nhwc
|
||||||
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
|
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
|
||||||
// bias [oC], may be nullptr
|
// bias [oC], may be nullptr
|
||||||
|
|
||||||
// output [bS, oH, oW, oC] nchw, mkl doesn't support format nhwc
|
// output [bS, oC, oH, oW] nchw, mkl doesn't support format nhwc
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
@ -179,12 +179,12 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||||
const int isSameMode) {
|
const int paddingMode) {
|
||||||
|
|
||||||
// input and gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
|
// input and gradI [bS, iC, iH, iW], mkl doesn't support ndhwc format
|
||||||
// weights and gradW [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
|
// weights and gradW [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
|
||||||
// gradB [oC], may be nullptr
|
// gradB [oC], may be nullptr
|
||||||
// gradO [bS, oH, oW, oC]
|
// gradO [bS, oC, oH, oW]
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
@ -368,19 +368,19 @@ PLATFORM_IMPL(deconv2d) {
|
||||||
int pW = INT_ARG(5); // paddings width
|
int pW = INT_ARG(5); // paddings width
|
||||||
int dH = INT_ARG(6); // dilations height
|
int dH = INT_ARG(6); // dilations height
|
||||||
int dW = INT_ARG(7); // dilations width
|
int dW = INT_ARG(7); // dilations width
|
||||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||||
|
|
||||||
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
||||||
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
if (bias)
|
if (bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
if(isSameMode){ // SAME
|
if(paddingMode){ // SAME
|
||||||
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
||||||
}
|
}
|
||||||
|
@ -394,7 +394,7 @@ PLATFORM_IMPL(deconv2d) {
|
||||||
output = new NDArray(output->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
output = new NDArray(output->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode);
|
deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
|
||||||
|
|
||||||
delete weights;
|
delete weights;
|
||||||
|
|
||||||
|
@ -419,14 +419,14 @@ PLATFORM_CHECK(deconv2d) {
|
||||||
|
|
||||||
int dH = INT_ARG(6); // dilations height
|
int dH = INT_ARG(6); // dilations height
|
||||||
int dW = INT_ARG(7); // dilations width
|
int dW = INT_ARG(7); // dilations width
|
||||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
|
||||||
const DataType xType = input->dataType();
|
const DataType xType = input->dataType();
|
||||||
const DataType wType = weights->dataType();
|
const DataType wType = weights->dataType();
|
||||||
const DataType zType = output->dataType();
|
const DataType zType = output->dataType();
|
||||||
const DataType bType = bias != nullptr ? bias->dataType() : zType;
|
const DataType bType = bias != nullptr ? bias->dataType() : zType;
|
||||||
|
|
||||||
return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) &&
|
return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) &&
|
||||||
(
|
(
|
||||||
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
||||||
|
@ -459,7 +459,7 @@ PLATFORM_IMPL(deconv2d_bp) {
|
||||||
int pW = INT_ARG(5); // paddings width
|
int pW = INT_ARG(5); // paddings width
|
||||||
int dH = INT_ARG(6); // dilations height
|
int dH = INT_ARG(6); // dilations height
|
||||||
int dW = INT_ARG(7); // dilations width
|
int dW = INT_ARG(7); // dilations width
|
||||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
|
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
@ -467,7 +467,7 @@ PLATFORM_IMPL(deconv2d_bp) {
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||||
|
|
||||||
int trueoH, trueoW; // true output height, width
|
int trueoH, trueoW; // true output height, width
|
||||||
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
|
||||||
|
|
||||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
||||||
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
||||||
|
@ -476,7 +476,7 @@ PLATFORM_IMPL(deconv2d_bp) {
|
||||||
if(bias)
|
if(bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
if(isSameMode){ // SAME
|
if(paddingMode){ // SAME
|
||||||
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
||||||
}
|
}
|
||||||
|
@ -492,7 +492,7 @@ PLATFORM_IMPL(deconv2d_bp) {
|
||||||
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode);
|
deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
|
||||||
|
|
||||||
delete weights;
|
delete weights;
|
||||||
delete gradW;
|
delete gradW;
|
||||||
|
@ -518,7 +518,7 @@ PLATFORM_CHECK(deconv2d_bp) {
|
||||||
|
|
||||||
int dH = INT_ARG(6); // dilations height
|
int dH = INT_ARG(6); // dilations height
|
||||||
int dW = INT_ARG(7); // dilations width
|
int dW = INT_ARG(7); // dilations width
|
||||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
|
||||||
const DataType xType = input->dataType();
|
const DataType xType = input->dataType();
|
||||||
const DataType wType = weights->dataType();
|
const DataType wType = weights->dataType();
|
||||||
|
@ -528,7 +528,7 @@ PLATFORM_CHECK(deconv2d_bp) {
|
||||||
const DataType gradWType = gradW->dataType();
|
const DataType gradWType = gradW->dataType();
|
||||||
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
|
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
|
||||||
|
|
||||||
return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
|
return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -34,8 +34,7 @@ namespace platforms {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
|
static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
|
||||||
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW,
|
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW,
|
||||||
const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
|
const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
|
||||||
const int isSameMode) {
|
|
||||||
|
|
||||||
// input [bS, iD, iH, iW, iC] ncdhw, mkl doesn't support format ndhwc
|
// input [bS, iD, iH, iW, iC] ncdhw, mkl doesn't support format ndhwc
|
||||||
// weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
|
// weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
|
||||||
|
@ -182,8 +181,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||||
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
|
const int kD, const int kH, const int kW,
|
||||||
const int isSameMode) {
|
const int sD, const int sH, const int sW,
|
||||||
|
const int pD, const int pH, const int pW,
|
||||||
|
const int dD, const int dH, const int dW) {
|
||||||
|
|
||||||
// input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format
|
// input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format
|
||||||
// weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
|
// weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
|
||||||
|
@ -408,7 +409,7 @@ PLATFORM_IMPL(deconv3d) {
|
||||||
output = new NDArray(output->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
output = new NDArray(output->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode);
|
deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||||
|
|
||||||
delete weights;
|
delete weights;
|
||||||
|
|
||||||
|
@ -509,7 +510,7 @@ PLATFORM_IMPL(deconv3d_bp) {
|
||||||
gradO = new NDArray(gradO->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
gradO = new NDArray(gradO->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode);
|
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||||
|
|
||||||
delete weights;
|
delete weights;
|
||||||
delete gradW;
|
delete gradW;
|
||||||
|
|
|
@ -0,0 +1,505 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
|
||||||
|
using namespace dnnl;
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
|
||||||
|
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||||
|
const int paddingMode, const bool isNCHW) {
|
||||||
|
|
||||||
|
// mkl supports only following case: mC = 1, oC = iC
|
||||||
|
|
||||||
|
// input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
|
||||||
|
// weights [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute
|
||||||
|
// bias [oC], may be nullptr
|
||||||
|
// output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
|
||||||
|
// oC = iC*mC
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
|
||||||
|
mC = weights->sizeAt(indWmC); // channels multiplier
|
||||||
|
|
||||||
|
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
|
||||||
|
|
||||||
|
dnnl::memory::dims strides = { sH, sW };
|
||||||
|
dnnl::memory::dims padding = { pH, pW };
|
||||||
|
dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
|
||||||
|
dnnl::memory::dims dilation = { dH-1, dW-1};
|
||||||
|
|
||||||
|
// input type
|
||||||
|
dnnl::memory::data_type xType;
|
||||||
|
if(input->dataType() == DataType::FLOAT32)
|
||||||
|
xType = dnnl::memory::data_type::f32;
|
||||||
|
else if(input->dataType() == DataType::HALF)
|
||||||
|
xType = dnnl::memory::data_type::f16;
|
||||||
|
else if(input->dataType() == DataType::UINT8)
|
||||||
|
xType = dnnl::memory::data_type::u8;
|
||||||
|
else
|
||||||
|
xType = dnnl::memory::data_type::s8;
|
||||||
|
|
||||||
|
// weights type
|
||||||
|
dnnl::memory::data_type wType = xType;
|
||||||
|
if(xType == dnnl::memory::data_type::u8)
|
||||||
|
wType = dnnl::memory::data_type::s8;
|
||||||
|
|
||||||
|
// output and bias type (have the same types)
|
||||||
|
dnnl::memory::data_type zType;
|
||||||
|
if(output->dataType() == DataType::FLOAT32)
|
||||||
|
zType = dnnl::memory::data_type::f32;
|
||||||
|
else if(output->dataType() == DataType::HALF)
|
||||||
|
zType = dnnl::memory::data_type::f16;
|
||||||
|
else if(output->dataType() == DataType::UINT8)
|
||||||
|
zType = dnnl::memory::data_type::u8;
|
||||||
|
else if(output->dataType() == DataType::INT8)
|
||||||
|
zType = dnnl::memory::data_type::s8;
|
||||||
|
else
|
||||||
|
zType = dnnl::memory::data_type::s32;
|
||||||
|
|
||||||
|
dnnl::memory::format_tag xzFrmat = dnnl::memory::format_tag::nchw;
|
||||||
|
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;
|
||||||
|
|
||||||
|
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||||
|
dnnl::memory::dims wDims = {iC, mC, 1, kH, kW};
|
||||||
|
dnnl::memory::dims zDims = {bS, oC, oH, oW};
|
||||||
|
|
||||||
|
// memory descriptors for arrays
|
||||||
|
|
||||||
|
// input
|
||||||
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
|
||||||
|
x_user_md.data.format_kind = dnnl_blocked; // overrides format NHWC -> NCHW
|
||||||
|
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||||
|
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3);
|
||||||
|
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
|
||||||
|
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
|
||||||
|
|
||||||
|
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
||||||
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||||
|
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // permute
|
||||||
|
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
|
||||||
|
w_user_md.data.format_desc.blocking.strides[2] = 0;
|
||||||
|
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0);
|
||||||
|
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1);
|
||||||
|
|
||||||
|
// bias
|
||||||
|
dnnl::memory::desc b_mkl_md;
|
||||||
|
if(bias != nullptr)
|
||||||
|
b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x);
|
||||||
|
|
||||||
|
// output
|
||||||
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||||
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat);
|
||||||
|
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
||||||
|
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 : 3);
|
||||||
|
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
|
||||||
|
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// operation primitive description
|
||||||
|
dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto,
|
||||||
|
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, dnnl::memory> args;
|
||||||
|
|
||||||
|
dnnl::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
|
// input
|
||||||
|
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
||||||
|
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// weights
|
||||||
|
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
||||||
|
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||||
|
auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||||
|
if (wReorder)
|
||||||
|
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||||
|
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
|
||||||
|
// bias
|
||||||
|
if(bias != nullptr) {
|
||||||
|
auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
|
||||||
|
args[DNNL_ARG_BIAS] = b_mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// output
|
||||||
|
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
|
||||||
|
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||||
|
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||||
|
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||||
|
|
||||||
|
// run calculations
|
||||||
|
dnnl::convolution_forward(op_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder outputs if necessary
|
||||||
|
if (zReorder)
|
||||||
|
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||||
|
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||||
|
const int paddingMode, const bool isNCHW) {
|
||||||
|
|
||||||
|
// mkl supports only following case: mC = 1, oC = iC
|
||||||
|
|
||||||
|
// input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
|
||||||
|
// weights, gradW [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute
|
||||||
|
// gradB [oC], may be nullptr
|
||||||
|
// gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
|
||||||
|
// oC = iC*mC
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
|
||||||
|
mC = weights->sizeAt(indWmC);
|
||||||
|
|
||||||
|
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
|
||||||
|
|
||||||
|
dnnl::memory::dims strides = { sH, sW };
|
||||||
|
dnnl::memory::dims padding = { pH, pW };
|
||||||
|
dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
|
||||||
|
dnnl::memory::dims dilation = { dH-1, dW-1};
|
||||||
|
|
||||||
|
// input type
|
||||||
|
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||||
|
// weights type
|
||||||
|
dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||||
|
// gradO type
|
||||||
|
dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||||
|
// gradI type
|
||||||
|
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||||
|
// gradW type
|
||||||
|
dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||||
|
// gradB type
|
||||||
|
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
|
||||||
|
|
||||||
|
dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||||
|
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;
|
||||||
|
|
||||||
|
dnnl::memory::dims xDims = {bS, iC, iH, iW};
|
||||||
|
dnnl::memory::dims wDims = {iC, mC, 1, kH, kW};
|
||||||
|
dnnl::memory::dims zDims = {bS, oC, oH, oW};
|
||||||
|
|
||||||
|
// memory descriptors for arrays
|
||||||
|
|
||||||
|
// input
|
||||||
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
|
||||||
|
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||||
|
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3);
|
||||||
|
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
|
||||||
|
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
|
||||||
|
|
||||||
|
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
||||||
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
|
||||||
|
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // permute
|
||||||
|
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
|
||||||
|
w_user_md.data.format_desc.blocking.strides[2] = 0;
|
||||||
|
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0);
|
||||||
|
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1);
|
||||||
|
|
||||||
|
// gradO
|
||||||
|
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||||
|
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
|
||||||
|
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 : 3);
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
|
||||||
|
|
||||||
|
// gradI
|
||||||
|
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||||
|
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
|
||||||
|
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 : 3);
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
|
||||||
|
|
||||||
|
// gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
|
||||||
|
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||||
|
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
|
||||||
|
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // permute
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[2] = 0;
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(0);
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(1);
|
||||||
|
|
||||||
|
// gradB
|
||||||
|
dnnl::memory::desc gradB_mkl_md;
|
||||||
|
if(gradB != nullptr)
|
||||||
|
gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// forward primitive description
|
||||||
|
dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||||
|
|
||||||
|
// backward data primitive description
|
||||||
|
dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// backward weights primitive description
|
||||||
|
dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, dnnl::memory> args;
|
||||||
|
|
||||||
|
dnnl::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
|
// input
|
||||||
|
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
||||||
|
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// weights
|
||||||
|
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
||||||
|
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||||
|
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||||
|
if (wReorder)
|
||||||
|
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||||
|
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
|
||||||
|
// gradO
|
||||||
|
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||||
|
auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||||
|
if (gradOReorder)
|
||||||
|
dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||||
|
args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||||
|
|
||||||
|
// gradI
|
||||||
|
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||||
|
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||||
|
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||||
|
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||||
|
|
||||||
|
// gradW
|
||||||
|
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
|
||||||
|
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||||
|
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||||
|
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||||
|
|
||||||
|
// gradB
|
||||||
|
if(gradB != nullptr) {
|
||||||
|
auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
|
||||||
|
args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// run backward data calculations
|
||||||
|
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// run backward weights calculations
|
||||||
|
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder gradI if necessary
|
||||||
|
if (gradIReorder)
|
||||||
|
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||||
|
if (gradWReorder)
|
||||||
|
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(depthwise_conv2d) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
||||||
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
|
||||||
|
|
||||||
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
|
||||||
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
||||||
|
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
|
||||||
|
mC = weights->sizeAt(indWmC); // channels multiplier
|
||||||
|
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
|
||||||
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC);
|
||||||
|
if (bias)
|
||||||
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_CHECK(depthwise_conv2d) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto weights = INPUT_VARIABLE(1);
|
||||||
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
|
||||||
|
|
||||||
|
auto output = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const DataType xType = input->dataType();
|
||||||
|
const DataType wType = weights->dataType();
|
||||||
|
const DataType zType = output->dataType();
|
||||||
|
const DataType bType = bias != nullptr ? bias->dataType() : zType;
|
||||||
|
|
||||||
|
const int mC = weights->sizeAt(3);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && mC == 1 &&
|
||||||
|
(
|
||||||
|
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||||
|
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF) ||
|
||||||
|
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(depthwise_conv2d_bp) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
||||||
|
REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf());
|
||||||
|
|
||||||
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
|
||||||
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
||||||
|
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
|
||||||
|
mC = weights->sizeAt(indWmC); // channels multiplier
|
||||||
|
|
||||||
|
int trueoH, trueoW; // correct output height, width
|
||||||
|
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
|
||||||
|
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
||||||
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
|
||||||
|
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
if(bias)
|
||||||
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_CHECK(depthwise_conv2d_bp) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
const DataType xType = input->dataType();
|
||||||
|
const DataType wType = weights->dataType();
|
||||||
|
const DataType gradOType = gradO->dataType();
|
||||||
|
|
||||||
|
const DataType gradIType = gradI->dataType();
|
||||||
|
const DataType gradWType = gradW->dataType();
|
||||||
|
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
|
||||||
|
|
||||||
|
const int mC = weights->sizeAt(3);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && mC == 1 && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -77,6 +77,9 @@ namespace nd4j{
|
||||||
DECLARE_PLATFORM(deconv2d_bp);
|
DECLARE_PLATFORM(deconv2d_bp);
|
||||||
|
|
||||||
DECLARE_PLATFORM(deconv3d_bp);
|
DECLARE_PLATFORM(deconv3d_bp);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(depthwise_conv2d);
|
||||||
|
DECLARE_PLATFORM(depthwise_conv2d_bp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1469,223 +1469,6 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) {
|
|
||||||
|
|
||||||
int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
|
||||||
int oC=iC*mC;
|
|
||||||
int oH=4,oW=3;
|
|
||||||
int paddingMode = 1; // 1-SAME, 0-VALID;
|
|
||||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
|
|
||||||
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC});
|
|
||||||
|
|
||||||
|
|
||||||
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
|
|
||||||
13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f,
|
|
||||||
12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
|
|
||||||
13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f});
|
|
||||||
input = 2.;
|
|
||||||
weights.linspace(0.1, 0.1);
|
|
||||||
|
|
||||||
nd4j::ops::depthwise_conv2d op;
|
|
||||||
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
|
||||||
auto* output = results->at(0);
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
|
||||||
|
|
||||||
ASSERT_TRUE(expOutput.isSameShape(output));
|
|
||||||
ASSERT_TRUE(expOutput.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(ConvolutionTests1, depthwise_conv2d_2) {
|
|
||||||
|
|
||||||
int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
|
||||||
int oC=iC*mC;
|
|
||||||
int oH=2,oW=2;
|
|
||||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
|
||||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
|
|
||||||
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
|
|
||||||
|
|
||||||
|
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f,
|
|
||||||
13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f});
|
|
||||||
input = 2.;
|
|
||||||
weights.linspace(0.1, 0.1);
|
|
||||||
|
|
||||||
nd4j::ops::depthwise_conv2d op;
|
|
||||||
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
|
||||||
auto* output = results->at(0);
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
|
||||||
|
|
||||||
ASSERT_TRUE(expOutput.isSameShape(output));
|
|
||||||
ASSERT_TRUE(expOutput.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(ConvolutionTests1, depthwise_conv2d_3) {
|
|
||||||
|
|
||||||
int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
|
||||||
int oC=iC*mC;
|
|
||||||
int oH=2,oW=2;
|
|
||||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
|
||||||
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, iC, iH, iW});
|
|
||||||
auto weights = NDArrayFactory::create<double>('c', {mC, iC, kH, kW});
|
|
||||||
auto biases = NDArrayFactory::create<double>('c', {iC*mC}, {1,2,3,4});
|
|
||||||
|
|
||||||
|
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {bS, oC, oH, oW},{5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8, 5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8});
|
|
||||||
input = 2.;
|
|
||||||
weights.linspace(0.1, 0.1);
|
|
||||||
weights.permutei({2,3,1,0});
|
|
||||||
|
|
||||||
nd4j::ops::depthwise_conv2d op;
|
|
||||||
auto results = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
|
||||||
auto* output = results->at(0);
|
|
||||||
|
|
||||||
// output->printIndexedBuffer();
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
|
||||||
|
|
||||||
ASSERT_TRUE(expOutput.isSameShape(output));
|
|
||||||
ASSERT_TRUE(expOutput.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(ConvolutionTests1, depthwise_conv2d_4) {
|
|
||||||
|
|
||||||
int bS=1, iH=111,iW=111, iC=32,mC=1, kH=7,kW=7, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1;
|
|
||||||
int oC=iC*mC;
|
|
||||||
int oH=56,oW=56;
|
|
||||||
|
|
||||||
int paddingMode = 1; // 1-SAME, 0-VALID;
|
|
||||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
|
||||||
|
|
||||||
const float unique = -1000000;
|
|
||||||
|
|
||||||
NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
|
|
||||||
NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32);
|
|
||||||
NDArray output('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32);
|
|
||||||
input.linspace(0.1, 0.0001);
|
|
||||||
weights = 0.5;
|
|
||||||
output = unique;
|
|
||||||
|
|
||||||
nd4j::ops::depthwise_conv2d op;
|
|
||||||
Nd4jStatus status = op.execute({&input, &weights}, {&output} , {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
|
||||||
|
|
||||||
for(Nd4jLong i=output.lengthOf()/1.5; i < output.lengthOf(); ++i)
|
|
||||||
ASSERT_EQ(output.e<float>(i) != unique, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(ConvolutionTests1, depthwise_conv2d_5) {
|
|
||||||
|
|
||||||
int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
|
||||||
int oC=iC*mC;
|
|
||||||
int oH=3,oW=3;
|
|
||||||
int paddingMode = 1; // 1-SAME, 0-VALID;
|
|
||||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
|
|
||||||
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
|
|
||||||
|
|
||||||
|
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.});
|
|
||||||
input.linspace(1.);
|
|
||||||
weights = 1.;
|
|
||||||
|
|
||||||
nd4j::ops::depthwise_conv2d op;
|
|
||||||
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
|
||||||
auto output = results->at(0);
|
|
||||||
// output->printIndexedBuffer();
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
|
||||||
|
|
||||||
ASSERT_TRUE(expOutput.isSameShape(output));
|
|
||||||
ASSERT_TRUE(expOutput.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(ConvolutionTests1, depthwise_conv2d_6) {
|
|
||||||
|
|
||||||
int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
|
||||||
int oC=iC*mC;
|
|
||||||
int oH=3,oW=3;
|
|
||||||
int paddingMode = 1; // 1-SAME, 0-VALID;
|
|
||||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
|
||||||
|
|
||||||
NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::DOUBLE);
|
|
||||||
|
|
||||||
NDArray expOutput('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.});
|
|
||||||
input.linspace(1.);
|
|
||||||
weights = 1.;
|
|
||||||
|
|
||||||
nd4j::ops::depthwise_conv2d op;
|
|
||||||
ResultSet* results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
|
||||||
NDArray* output = results->at(0);
|
|
||||||
// output.printIndexedBuffer();
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
|
||||||
|
|
||||||
ASSERT_TRUE(expOutput.isSameShape(output));
|
|
||||||
ASSERT_TRUE(expOutput.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(ConvolutionTests1, depthwise_conv2d_7) {
|
|
||||||
|
|
||||||
int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
|
||||||
int oC=iC*mC;
|
|
||||||
int oH=3,oW=3;
|
|
||||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
|
||||||
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
|
||||||
|
|
||||||
NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579,
|
|
||||||
0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804,
|
|
||||||
0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344});
|
|
||||||
NDArray weights('c', {kH, kW, iC, mC}, {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, 0.19896849989891052});
|
|
||||||
NDArray biases('c', {1,iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975});
|
|
||||||
|
|
||||||
NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978,
|
|
||||||
0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472,
|
|
||||||
0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504,
|
|
||||||
0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385,
|
|
||||||
0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922});
|
|
||||||
|
|
||||||
|
|
||||||
nd4j::ops::depthwise_conv2d op;
|
|
||||||
auto results = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
|
||||||
auto* output = results->at(0);
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
|
||||||
|
|
||||||
ASSERT_TRUE(expOutput.isSameShape(output));
|
|
||||||
ASSERT_TRUE(expOutput.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) {
|
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) {
|
||||||
|
|
||||||
|
@ -1695,15 +1478,15 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) {
|
||||||
int paddingMode = 1; // 1-SAME, 0-VALID;
|
int paddingMode = 1; // 1-SAME, 0-VALID;
|
||||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
|
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
|
||||||
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
|
auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});
|
||||||
auto bias = NDArrayFactory::create<double>('c', {oC}, {1,2,3,4});
|
auto bias = NDArrayFactory::create<float>('c', {oC}, {1,2,3,4});
|
||||||
auto gradO = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
|
auto gradO = NDArrayFactory::create<float>('c', {bS, oH, oW, oC});
|
||||||
|
|
||||||
auto expGradI = NDArrayFactory::create<double>('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308,
|
NDArray expGradI('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308,
|
||||||
1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028});
|
1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
auto expGradW = NDArrayFactory::create<double>('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08});
|
NDArray expGradW('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
@ -1734,14 +1517,14 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test2) {
|
||||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
|
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
|
||||||
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
|
auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});
|
||||||
auto bias = NDArrayFactory::create<double>('c', {oC}, {1,2,3,4});
|
auto bias = NDArrayFactory::create<float>('c', {oC}, {1,2,3,4});
|
||||||
auto gradO = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
|
auto gradO = NDArrayFactory::create<float>('c', {bS, oH, oW, oC});
|
||||||
|
|
||||||
auto expGradI = NDArrayFactory::create<double>('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729,
|
NDArray expGradI('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729,
|
||||||
0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481});
|
0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481}, nd4j::DataType::FLOAT32);
|
||||||
auto expGradW = NDArrayFactory::create<double>('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88});
|
NDArray expGradW('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
@ -1763,6 +1546,132 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test2) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test3) {
|
||||||
|
|
||||||
|
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
|
||||||
|
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {1, 16});
|
||||||
|
auto grad = NDArrayFactory::create<float>('c', {4, 16, 64, 64});
|
||||||
|
|
||||||
|
auto gradI = in.like();
|
||||||
|
auto gradW = w.like();
|
||||||
|
auto gradB = b.like();
|
||||||
|
|
||||||
|
nd4j:ops::depthwise_conv2d_bp op;
|
||||||
|
auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test4) {
|
||||||
|
|
||||||
|
int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
||||||
|
int oC=iC*mC;
|
||||||
|
int oH=10,oW=10;
|
||||||
|
int paddingMode = 1; // 1-SAME, 0-VALID;
|
||||||
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gradO('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray bias('c', {oC}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
input.linspace(-10, 0.1);
|
||||||
|
weights.linspace(-2, 0.1);
|
||||||
|
gradO.linspace(10, -0.1);
|
||||||
|
|
||||||
|
|
||||||
|
NDArray expGradI('c', {bS, iH, iW, iC},{10.880001, 13.239998, 15.520001, 17.719997, 19.840000, 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, 42.160000, 43.539997, 31.360001, 32.980000, 34.480000, 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, 31.060001, 30.639999, 30.099998, 29.440002, 28.660000, 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, 98.790009, 96.000008,
|
||||||
|
93.030006, 89.879990, 86.549988, 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, 19.619999, 15.360001, 10.980000, 6.480000, 1.859999, 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, -19.560005, -30.809998, -42.239998, -53.849991, -65.639992, -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, 8.639999, -0.540001, -9.839996, -19.259998, -28.799995, -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, -156.360001, -173.369995, -66.720001, -83.369995, -100.199997,
|
||||||
|
-117.209999, -134.399994, -151.769989, -169.319992, -187.049988, -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, -29.760002, -43.739998, -57.840000, -72.059998, -86.400009, -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, -286.410034, -311.879974, -337.530029, -170.400009, -195.689987, -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, -253.979980, -272.640045, -291.419983, -310.319977, -329.339996, -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, -271.559937, -300.809998, -330.239990, -359.849976, -389.639984,
|
||||||
|
-419.610016, -222.240036, -251.849960, -281.640015, -311.609985, -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, -345.179993, -368.640015, -392.220001, -415.919952, -439.740021, -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, -334.559967, -373.529999, -412.679993, -452.009949, -491.519989, -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, -609.959961, -652.169983, -352.320007, -380.220001,
|
||||||
|
-408.239990, -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, -629.760010, -678.089966, -726.600037, -775.289917, -446.880035,-495.210052, -543.719971, -592.410034, -641.279968, -690.330017, -739.559937, -788.970093, -429.120026, -461.819946, -494.639984, -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, -657.210022, -710.400024, -763.770081, -817.319946, -871.050049, -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, -624.360046, -678.809937, -733.440002,
|
||||||
|
-788.250000, -843.239990, -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, -744.960022, -800.489990, -856.200012, -912.090027, -533.280029, -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, -861.440063, -900.140015, -938.960022,-977.899963, -1016.960022, -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, -1043.839966, -1083.979980, -1124.239990, -1164.619995, -896.000000, -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expGradW('c', {kH, kW, iC, mC},{-104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875,
|
||||||
|
-107702.734375, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104824.789062,
|
||||||
|
-105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -126744.000000, -127277.710938, -127813.187500,
|
||||||
|
-128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -140944.000000, -141536.984375, -142131.984375, -142729.000000, -143328.000000,
|
||||||
|
-143929.015625, -144532.000000, -145137.000000, -126744.000000, -127277.710938, -127813.187500, -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -104824.789062, -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, -107702.734375}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::depthwise_conv2d_bp op;
|
||||||
|
ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
||||||
|
NDArray* gradI = results->at(0);
|
||||||
|
NDArray* gradW = results->at(1);
|
||||||
|
NDArray* gradB = results->at(2);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||||
|
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradW.isSameShape(gradW));
|
||||||
|
ASSERT_TRUE(expGradW.equalsTo(gradW));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradB.isSameShape(gradB));
|
||||||
|
ASSERT_TRUE(expGradB.equalsTo(gradB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test5) {
|
||||||
|
|
||||||
|
int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
||||||
|
int oC=iC*mC;
|
||||||
|
int oH=10,oW=10;
|
||||||
|
int paddingMode = 1; // 1-SAME, 0-VALID;
|
||||||
|
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gradO('c', {bS, oC, oH, oW}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray bias('c', {oC}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
input.linspace(-10, 0.1);
|
||||||
|
weights.linspace(-2, 0.1);
|
||||||
|
gradO.linspace(10, -0.1);
|
||||||
|
|
||||||
|
|
||||||
|
NDArray expGradI('c', {bS, iC, iH, iW}, {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, 18.320002, 16.880001, 15.440001, 14.000000, 9.759999, 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, 4.050000, 2.010000, 0.840000, -0.330000, -1.499999, -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, -17.550003, -44.790005, -45.959991, -47.129993, -48.300003, -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, -61.170002, -62.340004, -63.510002, -64.680000,
|
||||||
|
-60.149994, -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, -86.580002, -187.979996, -189.240005, -190.499985, -191.759995, -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, -205.620010, -206.880005, -208.139999, -209.399994, -183.060013, -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, 10.700003, -33.160004, -33.339996, -33.519993, -33.700001,
|
||||||
|
-33.879997, -34.059994, -34.239994, -34.419994, -57.299995, -129.209991, -269.969971, -271.319977, -272.670044, -274.019989, -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, -438.720001, -440.160065, -441.599976, -443.040039, -372.480011, -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, -228.239990, -461.759979, -463.200012, -464.639984, -466.079956, -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, -235.440002, -476.159912,
|
||||||
|
-477.600006, -479.040039, -480.479980, -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, -249.840012, -504.960052, -506.399963, -507.839996, -509.280029, -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, -357.269958, -705.750000, -707.279968, -708.809937, -710.340027, -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, -365.070038, -721.050049, -722.579956, -724.109985, -725.640015, -727.169922, -728.700012,
|
||||||
|
-730.229980, -731.760010, -605.789978, -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, -819.419922, -821.040039, -822.660034, -824.279968, -674.699951, -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, -548.729980, -1047.030029, -1048.740112, -1050.449829, -1052.160034, -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, -557.729980,
|
||||||
|
-1064.130005, -1065.840088, -1067.550049, -1069.260010, -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, -1122.270020, -1123.979980, -1125.689941, -1127.400024, -912.810059, -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expGradW('c', {kH, kW, iC, mC}, {-2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000,
|
||||||
|
-2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2594.701416, -2513.699951,
|
||||||
|
-18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -3043.501465, -2953.500244, -20863.500000, -56773.492188,
|
||||||
|
-110683.515625, -182593.515625, -272503.531250, -380413.562500, -3383.499756, -3283.500000, -23183.501953, -63083.500000, -122983.500000, -202883.515625,
|
||||||
|
-302783.531250, -422683.468750, -3043.501465, -2953.500244, -20863.500000, -56773.492188, -110683.515625, -182593.515625, -272503.531250, -380413.562500,
|
||||||
|
-2594.701416, -2513.699951, -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::depthwise_conv2d_bp op;
|
||||||
|
ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
||||||
|
NDArray* gradI = results->at(0);
|
||||||
|
NDArray* gradW = results->at(1);
|
||||||
|
NDArray* gradB = results->at(2);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||||
|
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradW.isSameShape(gradW));
|
||||||
|
ASSERT_TRUE(expGradW.equalsTo(gradW));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradB.isSameShape(gradB));
|
||||||
|
ASSERT_TRUE(expGradB.equalsTo(gradB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TYPED_TEST(TypedConvolutionTests1, conv3d_test1) {
|
TYPED_TEST(TypedConvolutionTests1, conv3d_test1) {
|
||||||
|
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -591,21 +591,6 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
|
|
||||||
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
|
|
||||||
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});
|
|
||||||
auto b = NDArrayFactory::create<float>('c', {1, 16});
|
|
||||||
auto grad = NDArrayFactory::create<float>('c', {4, 16, 64, 64});
|
|
||||||
|
|
||||||
auto gradI = in.like();
|
|
||||||
auto gradW = w.like();
|
|
||||||
auto gradB = b.like();
|
|
||||||
|
|
||||||
nd4j:ops::depthwise_conv2d_bp op;
|
|
||||||
auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}, {});
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, test_matmul_bp_1) {
|
TEST_F(DeclarableOpsTests15, test_matmul_bp_1) {
|
||||||
auto a = NDArrayFactory::create<double>('c', {1, 3});
|
auto a = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
auto b = NDArrayFactory::create<double>('c', {1, 4});
|
auto b = NDArrayFactory::create<double>('c', {1, 4});
|
||||||
|
|
Loading…
Reference in New Issue