Shyrma weights format (#329)
* - start to introduce additional weights formats into conv2d ops
* - provide weights format variety in backprop conv2d and deconv2d ops, testing and fixing bugs
* - forgot to recover kernels sizes in deconv2d_bp test
* - built in weights format in depthwise conv 2d op
* - provide new weights formats in mkl dnn conv ops
* - provide new weights formats in cuda conv helpers
* - working with new weights format in cudnn conv api
* - take into account order of arrays in cudnn tensor descriptions
* - provide new weights formats in cpu conv3d (ff/bp)
* - provide new weights formats in cpu deconv3d (ff/bp)
* - provide new weights formats in conv3d ops (ff/bp) based on mkl api
* - provide new weights formats in conv3d ops (ff/bp) based on cudnn api
* - resolve conflicts 2

Signed-off-by: Yurii <iuriish@yahoo.com>
Co-authored-by: raver119 <raver119@gmail.com>
parent 5dae4069cf
commit e700b59f80
@@ -4076,7 +4076,7 @@ INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShap
     // *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code

-    // FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities)
+    // FIXME - indeed we don't need to allocate so large memory amount (4*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities)
     Nd4jLong tempBuffer[4*MAX_RANK];
     Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides, *newStrides;
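Note on the hunk above: only the FIXME comment changes; it now matches the actual allocation, which the visible context already implies is two halves of 2*MAX_RANK each (oldShape at offset 0, newShape at offset 2*MAX_RANK, with the stride pointers carved out of the same storage), i.e. 4*MAX_RANK entries in total.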
@@ -34,7 +34,7 @@ namespace ops {
 CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {

     auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
-    auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

     auto output = OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW)
@@ -45,12 +45,13 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
     int dW = INT_ARG(3); // dilations width
     int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL
     int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC
+    int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]

     const int rank = 3;
     REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf());
     REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf());

-    int indIOioC, indIiW, indWoC(2);
+    int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0);
     if(!isNCW) {
         indIOioC = 2; indIiW = 1;
     }
@@ -63,7 +64,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
     int iC = input->sizeAt(indIOioC); // input channels
     int oC = weights->sizeAt(indWoC); // output channels

-    std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC}));
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if (bias)
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@@ -83,11 +84,11 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
     auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]

     sd::ops::conv2d conv2d;
-    const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
+    const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {});
     if (status != ND4J_STATUS_OK)
         return status;

-    // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW);
+    // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat);

     return Status::OK();
 }
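The ternary chain above is the entire wFormat contract for conv1d. A minimal standalone sketch of that dispatch (illustrative names, not the library's API):

    #include <cstdio>
    #include <vector>

    // wFormat selects the weights layout: 0 -> [kW, iC, oC], 1 -> [oC, iC, kW], 2 -> [oC, kW, iC]
    static std::vector<long long> expectedConv1dWeightsShape(int wFormat, long long kW, long long iC, long long oC) {
        if (wFormat == 0) return {kW, iC, oC};
        if (wFormat == 1) return {oC, iC, kW};
        return {oC, kW, iC};
    }

    int main() {
        auto s = expectedConv1dWeightsShape(1, 3, 8, 16);
        std::printf("[%lld, %lld, %lld]\n", s[0], s[1], s[2]); // prints [16, 8, 3]
    }

Note also how indWoC moves with the format: oC sits at index 2 only in the default layout, while in both new layouts it is the leading axis, hence `indWoC(0 == wFormat ? 2 : 0)`.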
@@ -105,8 +106,9 @@ DECLARE_SHAPE_FN(conv1d) {
     int dW = INT_ARG(3); // dilations width
     int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME
     int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
+    int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]

-    int indIOioC, indIiW, indWoC(2);
+    int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0);
     if(!isNCW) {
         indIOioC = 2; indIiW = 1;
     }
@@ -123,7 +125,7 @@ DECLARE_SHAPE_FN(conv1d) {
     int iC = inputShapeInfo[indIOioC+1]; // input channels
     int oC = weightsShapeInfo[indWoC+1]; // output channels

-    std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC}));
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
     if (biasShapeInfo)
         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@@ -163,12 +165,12 @@ DECLARE_TYPES(conv1d) {
 CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {

     auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
-    auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next

     auto gradI = OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon
-    auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC] always
+    auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]

     int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
@@ -177,12 +179,14 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
     int dW = INT_ARG(3); // dilations width
     int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL
     int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
+    int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]

     const int rank = 3;
     REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf());
     REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf());
     REQUIRE_TRUE(gradO->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradO->rankOf());
-    int indIOioC, indIiW, indWoC(2);
+
+    int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0);
     if(!isNCW) {
         indIOioC = 2; indIiW = 1;
     }
@@ -199,7 +203,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
     ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);

     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW});
-    std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC}));
     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if(bias)
@@ -222,11 +226,11 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
     auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]

     sd::ops::conv2d_bp conv2dBP;
-    auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
+    auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {});
     if (status != ND4J_STATUS_OK)
         return status;

-    // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW);
+    // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat);

     return Status::OK();
 }
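Both conv1d and conv1d_bp are thin wrappers: they reshape the 3-d problem to a degenerate 4-d one (kH = 1, sH = 1, pH = 0, dH = 1) and forward wFormat untouched, which works because the 1-d and 2-d format codes enumerate the same layouts with the height axis inserted. A sketch of the forwarded integer-argument vector (hypothetical helper name, mirroring the execute calls above):

    #include <vector>

    // iArgs layout expected by conv2d, per this diff:
    // {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}
    static std::vector<int> conv2dArgsFromConv1d(int kW, int sW, int pW, int dW,
                                                 int paddingMode, int isNCW, int wFormat) {
        return {1, kW, 1, sW, 0, pW, 1, dW, paddingMode, !isNCW, wFormat};
    }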
@@ -235,7 +239,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
 DECLARE_SHAPE_FN(conv1d_bp) {

     auto inputShapeInfo = inputShape->at(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
-    auto weightsShapeInfo = inputShape->at(1); // [kW, iC, oC] always
+    auto weightsShapeInfo = inputShape->at(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
     Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
     Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next
@@ -250,8 +254,9 @@ DECLARE_SHAPE_FN(conv1d_bp) {
     int dW = INT_ARG(3); // dilations width
     int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME
     int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
+    int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]

-    int indIOioC, indIiW, indWoC(2);
+    int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0);
     if(!isNCW) {
         indIOioC = 2; indIiW = 1;
     }
@@ -268,7 +273,7 @@ DECLARE_SHAPE_FN(conv1d_bp) {
     ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);

     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW});
-    std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC}));
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
     if(biasShapeInfo)
@@ -37,7 +37,7 @@ namespace ops {
 CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {

     auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

     auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@@ -49,21 +49,22 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
     int dH = INT_ARG(6); // dilations height
     int dW = INT_ARG(7); // dilations width
     int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
-    bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+    int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width

     int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if (bias)
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

-    ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
+    ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);

     return Status::OK();
 }
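For callers the change is backward compatible: wFormat is an optional eleventh integer argument defaulting to 0, the pre-existing layout. A hedged usage sketch, mirroring the execute calls that appear in this diff (fragment only; `input`, `weights`, `bias`, `output` are assumed NDArray pointers):

    // iArgs: kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, INT_ARG(9) (0-NCHW, 1-NHWC), wFormat
    sd::ops::conv2d op;
    auto status = op.execute({input, weights, bias}, {output}, {},
                             {3,3, 1,1, 0,0, 1,1, 1 /*SAME*/, 0 /*NCHW*/, 1 /*[oC, iC, kH, kW]*/}, {});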
@@ -73,7 +74,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
 DECLARE_SHAPE_FN(conv2d) {

     auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always
+    auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]

     //output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@@ -86,6 +87,7 @@ DECLARE_SHAPE_FN(conv2d) {
     int dW = INT_ARG(7); // dilations width
     int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) height
     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1)); // filter(kernel) width
@@ -95,7 +97,7 @@ DECLARE_SHAPE_FN(conv2d) {
     REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);

-    int indIOioC, indIiH, indWoC(3);
+    int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0);
     if(!isNCHW) {
         indIOioC = 3; indIiH = 1;
     }
@@ -109,7 +111,7 @@ DECLARE_SHAPE_FN(conv2d) {
     const int iC = inputShapeInfo[indIOioC+1]; // input channels
     const int oC = weightsShapeInfo[indWoC+1]; // output channels

-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
     if (biasShapeInfo)
         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
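The 2d/3d ops replace the open-coded ternaries with a new ConvolutionUtils::expectWeightsShape helper whose body is not shown in these hunks. A plausible reconstruction of the 4-d case from the layout comments above (the real signature and overloads may differ):

    #include <vector>
    typedef long long Nd4jLong;

    // Reconstructed sketch: the last argument pair is (iC, oC) for conv and (oC, iC)
    // for deconv, so one helper covers both by argument order.
    static std::vector<Nd4jLong> expectWeightsShape(int wFormat, Nd4jLong kH, Nd4jLong kW,
                                                    Nd4jLong c0, Nd4jLong c1) {
        if (0 == wFormat) return {kH, kW, c0, c1}; // kernel dims first (default)
        if (1 == wFormat) return {c1, c0, kH, kW}; // channels first
        return {c1, kH, kW, c0};                   // channels-last variant
    }

The deconv2d hunks below call it as expectWeightsShape(wFormat, kH, kW, oC, iC), i.e. with the channel arguments swapped relative to conv2d, which reproduces the [kH, kW, oC, iC] / [iC, oC, kH, kW] / [iC, kH, kW, oC] layouts documented there.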
@@ -157,12 +159,12 @@ DECLARE_SHAPE_FN(conv2d) {
 CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {

     auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

     auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-    auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always
+    auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]

     int kH = INT_ARG(0); // filter(kernel) height
@@ -175,6 +177,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
     int dW = INT_ARG(7); // dilations width
     int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

     REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
     REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
@@ -182,19 +185,19 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {

     int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     int trueoH, trueoW; // true output height, width
     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

     std::vector<Nd4jLong>expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
-    std::vector<Nd4jLong>expectedWeightsShape = {kH, kW, iC, oC};
+    std::vector<Nd4jLong>expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if(bias)
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

-    ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
+    ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);

     return Status::OK();
 }
@@ -204,7 +207,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
 DECLARE_SHAPE_FN(conv2d_bp) {

     auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always
+    auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
     auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
@@ -224,8 +227,9 @@ DECLARE_SHAPE_FN(conv2d_bp) {
     const int dW = INT_ARG(7); // dilations width
     const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
     const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+    const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

-    int indIOioC, indIiH, indOoH, indWoC(3);
+    int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 3 : 0);
     if(!isNCHW) {
         indIOioC = 3; indIiH = 1; indOoH = 1;
     }
@@ -243,7 +247,7 @@ DECLARE_SHAPE_FN(conv2d_bp) {
     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
     if(biasShapeInfo)
@@ -264,7 +268,7 @@ DECLARE_SHAPE_FN(conv2d_bp) {
 CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {

     auto gradIShape = INPUT_VARIABLE(0); // [4]
-    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

     auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
@@ -279,6 +283,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
     int dW = INT_ARG(7); // dilations width
     int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

     const int rank = gradO->rankOf();
@@ -295,17 +300,17 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {

     int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     int trueoH, trueoW; // true output height, width
     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());

-    ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
+    ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);

     return Status::OK();
 }
@@ -321,7 +326,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
 DECLARE_SHAPE_FN(conv2d_input_bp) {

     auto gradIShapeShapeInfo = inputShape->at(0); // [4]
-    auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always
+    auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

     const int rank = 4;
@@ -340,8 +345,9 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
     const int dW = INT_ARG(7); // dilations width
     const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
     const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+    const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

-    int indIOioC, indIiH, indWoC(3), indOoH;
+    int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH;
     if(!isNCHW) {
         indIOioC = 3; indIiH = 1; indOoH = 1;
     }
@@ -361,7 +367,7 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
@@ -32,7 +32,7 @@ namespace ops {
 CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
     auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
     auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
@@ -52,14 +52,15 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
     int dH = INT_ARG(10); // dilations height
     int dW = INT_ARG(11); // dilations width
     int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID
     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
+    int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]

     int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
-    std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if (bias)
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@@ -71,14 +72,24 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
     std::vector<int> permutForOutput;

     if (isNCDHW)
         permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
     else
         input = new NDArray(input->permute({0,4,1,2,3}));

+    std::vector<int> wAxes;
+    if(0 == wFormat)
+        wAxes = {3,0,1,2};
+    else if(1 == wFormat)
+        wAxes = {1,2,3,4};
+    else
+        wAxes = {4,1,2,3};
+
     NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
     ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
-    MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput); // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
+    // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
+    // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, iC, kD, kH, kW] = [bS, oD, oH, oW, oC]
+    // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, kD, kH, kW, iC] = [bS, oD, oH, oW, oC]
+    MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, wAxes, permutForOutput);

     if(bias)
         // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
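In the forward pass above, columns has shape [bS, iC, kD, kH, kW, oD, oH, oW], so the fixed axes {1,2,3,4} always name {iC, kD, kH, kW}; wAxes merely locates those same four dimensions inside each weights layout. A self-contained check of that mapping (illustrative only; it assumes tensorDot contracts the listed axis pairs in order):

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        // weights axis names per format, straight from the diff comments
        std::vector<std::vector<std::string>> layouts = {
            {"kD", "kH", "kW", "iC", "oC"},   // wFormat 0
            {"oC", "iC", "kD", "kH", "kW"},   // wFormat 1
            {"oC", "kD", "kH", "kW", "iC"}};  // wFormat 2
        std::vector<std::vector<int>> wAxes = {{3,0,1,2}, {1,2,3,4}, {4,1,2,3}};
        for (int f = 0; f < 3; ++f) {         // each row prints: iC kD kH kW
            for (int a : wAxes[f]) std::printf("%s ", layouts[f][a].c_str());
            std::printf("\n");
        }
    }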
@@ -101,7 +112,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
 DECLARE_SHAPE_FN(conv3dnew) {

     auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC] always
+    auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]

     int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
@@ -118,13 +129,14 @@ DECLARE_SHAPE_FN(conv3dnew) {
     int dW = INT_ARG(11); // dilations width
     int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID;
     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
+    int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]

     const int rank = 5;
     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
     REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo);
     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);

-    int indIOioC, indIiD, indWoC(4);
+    int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0);
     if(!isNCDHW) {
         indIOioC = 4; indIiD = 1;
     }
@@ -139,7 +151,7 @@ DECLARE_SHAPE_FN(conv3dnew) {
     int iC = inputShapeInfo[indIOioC+1]; // input channels
     int oC = weightsShapeInfo[indWoC+1]; // output channels

-    std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
     if (biasShapeInfo)
         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@@ -174,12 +186,12 @@ DECLARE_SHAPE_FN(conv3dnew) {
 CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {

     auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

     auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
-    auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
+    auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]

     REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
@@ -200,17 +212,18 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
     int dW = INT_ARG(11); // dilations width
     int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
+    int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]

     int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     int trueoD, trueoH, trueoW; // true output depth/height/width
     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
-    std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if(bias)
@@ -231,10 +244,25 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
         gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW
     }

+    std::vector<int> wPermut, colPermut;
+
+    if(0 == wFormat) {
+        wPermut = {3,0,1,2,4};
+        colPermut = {2,3,4,1,0,5,6,7};
+    }
+    else if(1 == wFormat) {
+        wPermut = {1,2,3,4,0};
+        colPermut = {1,2,3,4,0,5,6,7};
+    }
+    else {
+        wPermut = {4,1,2,3,0};
+        colPermut = {2,3,4,1,0,5,6,7};
+    }
+
     // ----- calculation of gradW and gradB ----- //
     NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
     ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
-    MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
+    MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, wPermut); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]

     //----- calculation of gradO -----//
     if(gradB) {
@@ -246,7 +274,10 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
     }

     //----- calculation of gradI -----//
-    MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
+    // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
+    // [oC, iC, kD, kH, kW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
+    // [oC, kD, kH, kW, iC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
+    MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut);
     ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]

     if(!isNCDHW) {
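One way to read wPermut: tensorDot writes the contraction result, whose natural axis order here is [iC, kD, kH, kW, oC], into gradW, and the permutation maps gradW's own layout onto that order. A quick check of that reading (an assumption about the permutation convention, but it is consistent with all three cases above):

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        std::vector<std::vector<std::string>> gradWLayouts = {
            {"kD", "kH", "kW", "iC", "oC"},   // wFormat 0
            {"oC", "iC", "kD", "kH", "kW"},   // wFormat 1
            {"oC", "kD", "kH", "kW", "iC"}};  // wFormat 2
        std::vector<std::vector<int>> wPermut = {{3,0,1,2,4}, {1,2,3,4,0}, {4,1,2,3,0}};
        for (int f = 0; f < 3; ++f) {         // every row prints: iC kD kH kW oC
            for (int a : wPermut[f]) std::printf("%s ", gradWLayouts[f][a].c_str());
            std::printf("\n");
        }
    }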
@@ -270,7 +301,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
 DECLARE_SHAPE_FN(conv3dnew_bp) {

     Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC] always
+    Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
     Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
@@ -288,6 +319,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
     int dW = INT_ARG(11); // dilations width
     int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
+    int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]

     const int rank = 5;
     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
@@ -295,7 +327,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);
     REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo);

-    int indIOioC, indIiD, indWoC(4);
+    int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0);
     if(!isNCDHW) {
         indIOioC = 4; indIiD = 1;
     }
@@ -314,7 +346,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2});
-    std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
     if(biasShapeInfo)
@@ -35,7 +35,7 @@ namespace ops {
 CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {

     auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
+    auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

     auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@@ -53,12 +53,13 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
     int dW = INT_ARG(7); // dilations width
     int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]

     int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if (bias)
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@@ -66,6 +67,12 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
     if(!isNCHW)
         output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]

+    std::vector<int> colPermut;
+    if(1 == wFormat)
+        colPermut = {1, 2, 3, 0, 4, 5};
+    else
+        colPermut = {2, 3, 1, 0, 4, 5};
+
     if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
@@ -73,8 +80,9 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {

     //----- calculation of output -----//
     // NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
     // NCHW: [kH, kW, oC, iC] x [bS, iC, iH, iW] = [kH, kW, oC, bS, iH, iW]
-    sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 1, 0, 4, 5});
+    // NHWC: [iC, oC, kH, kW] x [bS, iH, iW, iC] = [oC, kH, kW, bS, iH, iW]
+    // NHWC: [iC, kH, kW, oC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
+    sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut);
     LaunchContext* ctx = block.launchContext();
     helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW]
@@ -97,7 +105,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
 DECLARE_SHAPE_FN(deconv2d) {

     auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC] always
+    auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
     auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]

     const int rank = 4;
@@ -114,8 +122,9 @@ DECLARE_SHAPE_FN(deconv2d) {
     int dW = INT_ARG(7); // dilations width
     int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
+    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]

-    int indIOioC, indIiH, indWoC(2);
+    int indIOioC, indIiH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3));
     if(!isNCHW) {
         indIOioC = 3; indIiH = 1;
     }
@@ -129,7 +138,7 @@ DECLARE_SHAPE_FN(deconv2d) {
     const int iC = inputShapeInfo[indIOioC+1]; // input channels
     const int oC = weightsShapeInfo[indWoC+1]; // output channels

-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
     if (biasShapeInfo)
         REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@@ -163,12 +172,12 @@ DECLARE_SHAPE_FN(deconv2d) {
 CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {

     auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
+    auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next

     auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
-    auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
+    auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]

     REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
@@ -186,16 +195,17 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
     int dW = INT_ARG(7); // dilations width
     int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
+    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]

     int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

     int trueoH, trueoW; // true output height, width
     ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if(bias)
@ -206,29 +216,34 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
|
|||
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
||||
}
|
||||
|
||||
|
||||
// ----- calculation of gradI -> pass it through conv2d_ff ----- //
|
||||
// ----- calculation of gradI -> pass it through conv2d_ff ----- //
|
||||
sd::ops::conv2d conv2d;
|
||||
const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, !isNCHW}, {});
|
||||
const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, !isNCHW, wFormat}, {});
|
||||
if (status != ND4J_STATUS_OK)
|
||||
return status;
|
||||
|
// ----- prepare permutation arrays and axes for dot product ----- //
std::vector<int> inputAxesForDot;
std::vector<int> inputAxes;

if(!isNCHW) {
gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
inputAxesForDot = {0, 1, 2}; // bS, iH, iW
inputAxes = {0, 1, 2}; // bS, iH, iW
}
else
inputAxesForDot = {0, 2, 3}; // bS, iH, iW
inputAxes = {0, 2, 3}; // bS, iH, iW

std::vector<int> gradWAxes; // empty for wFormat = 1
if(0 == wFormat)
gradWAxes = {3, 2, 0, 1};
else if(2 == wFormat)
gradWAxes = {0, 3, 1, 2};

// ----- calculation of gradW ----- //
NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext());

LaunchContext* ctx = block.launchContext();
helpers::im2col(*ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, oC, oH, oW] is convoluted to [bS, oC, kH, kW, iH, iW]
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 4, 5}, {3, 2, 0, 1}); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW]
MmulHelper::tensorDot(input, &columns, gradW, inputAxes, {0, 4, 5}, gradWAxes); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW]
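Why is gradWAxes empty for wFormat 1? The raw tensorDot result is laid out [iC, oC, kH, kW], which is exactly the wFormat-1 weights layout; for the other two formats the axes vector maps gradW's axes onto that order. A standalone illustration (assuming tensorDot's last argument permutes the output view this way; not repository code):

```cpp
#include <cstdio>
#include <vector>

int main() {
    // gradW axis names per weights format, taken from the comments above
    std::vector<const char*> w0 = {"kH", "kW", "oC", "iC"};  // wFormat 0
    std::vector<const char*> w2 = {"iC", "kH", "kW", "oC"};  // wFormat 2
    std::vector<int> axes0 = {3, 2, 0, 1}, axes2 = {0, 3, 1, 2};
    // both lines print "iC oC kH kW", the raw tensorDot result order
    for (int i = 0; i < 4; ++i) std::printf("%s ", w0[axes0[i]]);
    std::printf("\n");
    for (int i = 0; i < 4; ++i) std::printf("%s ", w2[axes2[i]]);
    std::printf("\n");
    return 0;
}
```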
// ----- calculation of gradB ----- //
if(gradB) {

@ -248,7 +263,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
DECLARE_SHAPE_FN(deconv2d_bp) {

auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC] always
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
@ -267,8 +282,9 @@ DECLARE_SHAPE_FN(deconv2d_bp) {
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]

int indIOioC, indIiH, indWoC(2), indOoH;
int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3));
if(!isNCHW) {
indIOioC = 3; indIiH = 1; indOoH = 1;
}
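The ternary above just records where the oC axis sits in each of the three deconv2d weights layouts. As a hypothetical helper (names are illustrative, not repository API):

```cpp
// wFormat 0: [kH, kW, oC, iC] -> axis 2
// wFormat 1: [iC, oC, kH, kW] -> axis 1
// wFormat 2: [iC, kH, kW, oC] -> axis 3
inline int deconv2dWeightsOutChannelAxis(int wFormat) {
    return 0 == wFormat ? 2 : (1 == wFormat ? 1 : 3);
}
```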
@ -286,7 +302,7 @@ DECLARE_SHAPE_FN(deconv2d_bp) {
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
@ -32,10 +32,10 @@ namespace ops {
CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {

auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)

auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width

@ -47,6 +47,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

const int rank = gradO->rankOf();
@ -57,20 +58,19 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
// create empty conv2d input array
NDArray input(gradO->ordering(), gradIShape->asVectorT<Nd4jLong>(), gradO->dataType(), block.launchContext());

int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());

ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);

return Status::OK();
}
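Aside: deconv2d_tf behaves like TensorFlow's conv2d_backprop_input. It takes the desired forward-input shape as data, allocates an empty input of that shape, treats gradO as the forward-pass output gradient, and reuses conv2dBP to produce gradI; supporting the new layouts therefore only requires threading wFormat through getSizesAndIndexesConv2d, expectWeightsShape, and the conv2dBP call, as the hunks above and below show.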
@ -84,7 +84,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
DECLARE_SHAPE_FN(deconv2d_tf) {

auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradIShapeShapeInfo = inputShape->at(0); // [4]

const int rank = 4;

@ -103,8 +103,9 @@ DECLARE_SHAPE_FN(deconv2d_tf) {
const int dW = INT_ARG(7); // dilations width
const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

int indIOioC, indIiH, indWoC(3), indOoH;
int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH;
if(!isNCHW) {
indIOioC = 3; indIiH = 1; indOoH = 1;
}

@ -126,7 +127,7 @@ DECLARE_SHAPE_FN(deconv2d_tf) {
ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode);

std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str());
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
@ -32,7 +32,7 @@ namespace ops {
CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {

auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

@ -53,13 +53,14 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-VALID, 1-SAME
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]

int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -67,16 +68,23 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
if(!isNCDHW)
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]

std::vector<int> colPermut;
if(1 == wFormat)
colPermut = {1,2,3,4,0,5,6,7};
else
colPermut = {2,3,4,1,0,5,6,7};

if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());

//----- calculation of output -----//
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
// NCDHW: [kD, kH, kW, oC, iC] x [bS, iC, iD, iH, iW] = [kD, kH, kW, oC, bS, iD, iH, iW]
sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
// [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
// [iC, oC, kD, kH, kW] x [bS, iD, iH, iW, iC] = [oC, kD, kH, kW, bS, iD, iH, iW]
// [iC, kD, kH, kW, oC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
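colPermut plays the same role here that gradWAxes plays in the 2D backprop: tensorDot's last argument lines the columns array up with whatever axis order the weights-by-input product naturally produces for the given weights layout. A standalone illustration (assuming that view semantics; not repository code):

```cpp
#include <cstdio>
#include <vector>

int main() {
    // columns is allocated as [bS, oC, kD, kH, kW, iD, iH, iW]
    std::vector<const char*> cols = {"bS", "oC", "kD", "kH", "kW", "iD", "iH", "iW"};
    std::vector<int> permW1  = {1, 2, 3, 4, 0, 5, 6, 7};  // wFormat 1
    std::vector<int> permW02 = {2, 3, 4, 1, 0, 5, 6, 7};  // wFormat 0 and 2
    // prints "oC kD kH kW bS iD iH iW", the product order for [iC, oC, kD, kH, kW]
    for (int i : permW1)  std::printf("%s ", cols[i]);
    std::printf("\n");
    // prints "kD kH kW oC bS iD iH iW", the product order for the other layouts
    for (int i : permW02) std::printf("%s ", cols[i]);
    std::printf("\n");
    return 0;
}
```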
//----- add biases if required -----//
if(bias)
@ -101,7 +109,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
DECLARE_SHAPE_FN(deconv3d) {

auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC] always
auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]

const int rank = 5;

@ -122,8 +130,9 @@ DECLARE_SHAPE_FN(deconv3d) {
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-VALID, 1-SAME
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]

int indIOioC, indIiD, indWoC(3);
int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4));
if(!isNCDHW) {
indIOioC = 4; indIiD = 1;
}

@ -138,7 +147,7 @@ DECLARE_SHAPE_FN(deconv3d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels

std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo));
@ -174,12 +183,12 @@ DECLARE_SHAPE_FN(deconv3d) {
CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {

auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]

REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());

@ -201,16 +210,17 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-VALID, 1-SAME
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]

int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)

@ -221,7 +231,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {

// ----- calculation of gradI -> pass it through conv3d_ff ----- //
sd::ops::conv3dnew conv3d;
const Nd4jStatus status = conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, isSameMode, !isNCDHW}, {});
const Nd4jStatus status = conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, isSameMode, !isNCDHW, wFormat}, {});
if (status != ND4J_STATUS_OK)
return status;

@ -235,10 +245,16 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
else
inputAxesForDot = {0, 2, 3, 4}; // bS, iD, iH, iW

std::vector<int> gradWAxes; // empty for wFormat = 1
if(0 == wFormat)
gradWAxes = {4,3,0,1,2};
else if(2 == wFormat)
gradWAxes = {0,4,1,2,3};

// ----- calculation of gradW ----- //
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, gradWAxes); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]

// ----- calculation of gradB ----- //
if(gradB) {
|
|||
DECLARE_SHAPE_FN(deconv3d_bp) {
|
||||
|
||||
auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC] always
|
||||
auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
|
||||
Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
|
||||
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
|
||||
|
@ -290,8 +306,9 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
|
|||
int dW = INT_ARG(11); // dilations width
|
||||
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
|
||||
|
||||
int indIOioC, indIiD, indWoC(3);
|
||||
int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4));
|
||||
if(!isNCDHW) {
|
||||
indIOioC = 4; indIiD = 1;
|
||||
}
|
||||
|
@ -310,8 +327,8 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
|
|||
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
||||
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2});
|
||||
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
|
||||
REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
||||
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
|
||||
REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
||||
REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
||||
if(biasShapeInfo)
|
||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace ops {
CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {

auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC

auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)

@ -50,19 +50,20 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier

std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC);
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);

return Status::OK();
}
@ -75,7 +76,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
DECLARE_SHAPE_FN(depthwise_conv2d) {

Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, mC] always
Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC

const int rank = 4;

@ -92,8 +93,9 @@ DECLARE_SHAPE_FN(depthwise_conv2d) {
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

int indIOioC, indIiH, indWmC(3);
int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0);
if(!isNCHW) {
indIOioC = 3; indIiH = 1;
}

@ -109,7 +111,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d) {
const int oC = iC*mC; // output channels

std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo));
@ -148,12 +150,12 @@ DECLARE_SHAPE_FN(depthwise_conv2d) {
CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {

auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]

REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());

@ -170,23 +172,24 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier

int trueoH, trueoW; // correct output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);

return Status::OK();
}
@ -214,8 +217,9 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) {
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

int indIOioC, indIiH, indWmC(3);
int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0);
if(!isNCHW) {
indIOioC = 3; indIiH = 1;
}

@ -234,7 +238,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) {
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
@ -29,7 +29,7 @@ namespace ops {
CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {

auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC] always
auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

auto output = OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW)

@ -47,18 +47,19 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
int pW = 0; // paddings width
int dH = 1; // dilations height
int dW = 1; // dilations width
int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC]

int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW, wFormat);

return Status::OK();
}
@ -73,7 +74,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
DECLARE_SHAPE_FN(pointwise_conv2d) {

Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
Nd4jLong* weightsShapeInfo = inputShape->at(1); // [1, 1, iC, oC] always
Nd4jLong* weightsShapeInfo = inputShape->at(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC]
Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]

const int rank = 4;

@ -81,8 +82,9 @@ DECLARE_SHAPE_FN(pointwise_conv2d) {
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);

int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC]

int indIOioC, indWoC(3);
int indIOioC, indWoC(0 == wFormat ? 3 : 0);
if(!isNCHW)
indIOioC = 3;
else

@ -92,7 +94,7 @@ DECLARE_SHAPE_FN(pointwise_conv2d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels

std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC};
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -33,8 +33,8 @@ namespace ops {
CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {

NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC

NDArray *output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

@ -66,17 +66,19 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier

std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
if(weightsPoint) {
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC);
REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
}
if (bias)

@ -84,11 +86,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {

if (iC == 1) {
nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n","");
ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);
return Status::OK();
}

ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);

return Status::OK();
}
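For intuition about what sconv2d computes, here is a minimal naive reference in plain C++. It is a sketch under simplifying assumptions (NHWC, wFormat 0, unit stride, VALID padding, no bias, and depthwise output channel ordering c*mC + m), not the library's implementation:

```cpp
#include <cstddef>
#include <vector>
using std::vector;

// in: [bS, iH, iW, iC] NHWC input; wD: [kH, kW, iC, mC] depthwise weights;
// wP: [1, 1, iC*mC, oC] pointwise weights; returns [bS, oH, oW, oC]
vector<float> sconv2dNaive(const vector<float>& in, int bS, int iH, int iW, int iC,
                           const vector<float>& wD, int kH, int kW, int mC,
                           const vector<float>& wP, int oC) {
    const int oH = iH - kH + 1, oW = iW - kW + 1, dC = iC * mC;
    vector<float> mid(static_cast<std::size_t>(bS) * oH * oW * dC, 0.f);
    vector<float> out(static_cast<std::size_t>(bS) * oH * oW * oC, 0.f);
    auto IN = [&](int b, int h, int w, int c) { return in[((b * iH + h) * iW + w) * iC + c]; };
    // depthwise step: input channel c feeds output channels c*mC .. c*mC + mC-1
    for (int b = 0; b < bS; ++b)
        for (int h = 0; h < oH; ++h)
            for (int w = 0; w < oW; ++w)
                for (int c = 0; c < iC; ++c)
                    for (int m = 0; m < mC; ++m) {
                        float s = 0.f;
                        for (int y = 0; y < kH; ++y)
                            for (int x = 0; x < kW; ++x)
                                s += IN(b, h + y, w + x, c) * wD[((y * kW + x) * iC + c) * mC + m];
                        mid[((b * oH + h) * oW + w) * dC + c * mC + m] = s;
                    }
    // pointwise step: a 1x1 conv is a per-pixel matmul over channels
    for (int b = 0; b < bS; ++b)
        for (int h = 0; h < oH; ++h)
            for (int w = 0; w < oW; ++w)
                for (int o = 0; o < oC; ++o) {
                    float s = 0.f;
                    for (int d = 0; d < dC; ++d)
                        s += mid[((b * oH + h) * oW + w) * dC + d] * wP[d * oC + o];
                    out[((b * oH + h) * oW + w) * oC + o] = s;
                }
    return out;
}
```

The pointwise step being a plain per-pixel matrix multiply is why the real op can delegate it to conv2d with kH = kW = 1, as the backprop below does.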
@ -103,8 +105,8 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
DECLARE_SHAPE_FN(sconv2d) {

auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weightsDShapeInfo = inputShape->at(1); // [kH, kW, iC, mC] always
Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC] always
auto weightsDShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
Nd4jLong* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr

if(block.width() == 3)

@ -135,8 +137,9 @@ DECLARE_SHAPE_FN(sconv2d) {
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

int indIOioC, indIiH, indWmC(3);
int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0);
if(!isNCHW) {
indIOioC = 3; indIiH = 1;
}

@ -148,13 +151,13 @@ DECLARE_SHAPE_FN(sconv2d) {
const int iH = inputShapeInfo[indIiH+1]; // input height
const int iW = inputShapeInfo[indIiH+2]; // input width
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier
const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC)

std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
if(weightsPShapeInfo) {
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
}
if (biasShapeInfo)
@ -195,13 +198,13 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {

NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
NDArray *gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
NDArray *weightsDepth = INPUT_VARIABLE(2); // [kH, kW, iC, mC] always
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
NDArray *weightsDepth = INPUT_VARIABLE(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr

NDArray *gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always
NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC] always
NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
NDArray *gradB = nullptr; // [oC]

if(block.width() == 4) {

@ -244,17 +247,18 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier

std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str());
if(weightsPoint) {
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC);
REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(gradWP).c_str());
}

@ -274,12 +278,12 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {

auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC});
auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext());
ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);

auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)

ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW
ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW, wFormat); // in this case oH=iH and oW=iW

gradO = gradIDepth;
bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step

@ -288,7 +292,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
}

// ----- apply depthwise_conv2d_bp ----- //
ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);

if(weightsPoint)
delete gradO;
@ -301,8 +305,8 @@ DECLARE_SHAPE_FN(sconv2d_bp) {

auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradOShapeInfo = inputShape->at(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto weightsDShapeInfo = inputShape->at(2); // [kH, kW, iC, mC] always
Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC] always
auto weightsDShapeInfo = inputShape->at(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
Nd4jLong* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr

if(block.width() == 4) {

@ -335,8 +339,9 @@ DECLARE_SHAPE_FN(sconv2d_bp) {
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

int indIOioC, indIiH, indWmC(3);
int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0);
if(!isNCHW) {
indIOioC = 3; indIiH = 1;
}

@ -356,10 +361,10 @@ DECLARE_SHAPE_FN(sconv2d_bp) {

std::vector<Nd4jLong> expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1});
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
if(weightsPShapeInfo) {
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
}
if (biasShapeInfo)
@ -166,7 +166,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {

int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
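The pooling ops in this hunk and the ones that follow have no weights input at all; only the signature of getSizesAndIndexesConv2d/getSizesAndIndexesConv3d changed, so every pooling call site passes a literal 0 for the new wFormat parameter and is otherwise untouched.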
@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
|
|||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||
@@ -172,7 +172,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
@@ -168,7 +168,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
@@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
@@ -174,7 +174,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
@@ -167,7 +167,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
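The eight pooling hunks above are mechanical: pooling ops carry no weights array, so the widened getSizesAndIndexesConv2d/getSizesAndIndexesConv3d signatures are satisfied with a literal 0 in the new wFormat slot. A minimal sketch of the alternative that was not taken (the simplified signature below is hypothetical, not the library's): a defaulted parameter would have left these call sites untouched, at the cost of the compiler no longer forcing a review of every caller.

#include <cstdio>

static void getSizes(bool isNCHW, int wFormat = 0) {   // hypothetical simplified signature
    std::printf("isNCHW=%d wFormat=%d\n", (int)isNCHW, wFormat);
}

int main() {
    getSizes(true);       // pooling ops: no weights, format argument irrelevant
    getSizes(false, 2);   // conv ops: pass the real weights format
    return 0;
}

Passing an explicit 0 everywhere, as this commit does, makes each weight-less call site visible in review rather than silently defaulted.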
@@ -154,15 +154,24 @@ namespace sd {
}

// evaluates sizes values and indexes using input and output arrays depending on data format
static inline void getSizesAndIndexesConv2d(const bool isNCHW, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) {
    getSizesAndIndexesConv2d(isNCHW, input.getShapeInfo(), output.getShapeInfo(), bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) {
    getSizesAndIndexesConv2d(isNCHW, wFormat, input.getShapeInfo(), output.getShapeInfo(), bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
}

static inline void getSizesAndIndexesConv2d(const bool isNCHW, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) {
static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) {
    // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    // weights [kH, kW, iC, oC] always
    // weights [kH, kW, iC, oC] (wFormat = 0), [oC, iC, kH, kW] (wFormat = 1), [oC, kH, kW, iC] (wFormat = 2)
    // output  [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
    indWkH = 0; indWiC = 2; indWoC = 3;

    if(0 == wFormat) {
        indWkH = 0; indWiC = 2; indWoC = 3;
    }
    else if(1 == wFormat) {
        indWkH = 2; indWiC = 1; indWoC = 0;
    }
    else {
        indWkH = 1; indWiC = 3; indWoC = 0;
    }

    if(!isNCHW) {
        indIOioC = 3; indIiH = 1; indOoH = 1;
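The index mapping above can be spot-checked in isolation. A standalone sketch (plain arrays stand in for NDArray shape info, and the sizes 3/16/32 are made up): for every wFormat, the computed indexes must point at the same kernel, input-channel and output-channel extents.

#include <cassert>

int main() {
    const long long wShape[3][4] = { {3, 3, 16, 32},    // wFormat 0: [kH, kW, iC, oC]
                                     {32, 16, 3, 3},    // wFormat 1: [oC, iC, kH, kW]
                                     {32, 3, 3, 16} };  // wFormat 2: [oC, kH, kW, iC]
    for (int wFormat = 0; wFormat < 3; ++wFormat) {
        int indWkH, indWiC, indWoC;
        if (0 == wFormat)      { indWkH = 0; indWiC = 2; indWoC = 3; }
        else if (1 == wFormat) { indWkH = 2; indWiC = 1; indWoC = 0; }
        else                   { indWkH = 1; indWiC = 3; indWoC = 0; }
        // all three layouts describe the same 3x3 kernel mapping 16 -> 32 channels
        assert(wShape[wFormat][indWkH] == 3);
        assert(wShape[wFormat][indWiC] == 16);
        assert(wShape[wFormat][indWoC] == 32);
    }
    return 0;
}
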
@@ -181,12 +190,21 @@ namespace sd {
}

// evaluates sizes values and indexes using input and output arrays depending on data format
static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) {
static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) {
    // input   [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    // weights [kD, kH, kW, iC, oC] (NDHWC) or [oC, iC, kD, kH, kW] (NCDHW)
    // weights [kD, kH, kW, iC, oC] (wFormat = 0), [oC, iC, kD, kH, kW] (wFormat = 1), [oC, kD, kH, kW, iC] (wFormat = 2)
    // output  [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

    indWkD = 0; indWiC = 3; indWoC = 4;
    if(0 == wFormat) {
        indWkD = 0; indWiC = 3; indWoC = 4;
    }
    else if(1 == wFormat) {
        indWkD = 2; indWiC = 1; indWoC = 0;
    }
    else {
        indWkD = 1; indWiC = 4; indWoC = 0;
    }

    if(!isNCDHW) {
        indIOioC = 4; indIOioD = 1;
    }
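The same spot check works for the 3D mapping (again a standalone sketch with made-up sizes kD = 2, iC = 8, oC = 16):

#include <cassert>

int main() {
    const long long wShape[3][5] = { {2, 3, 3, 8, 16},    // wFormat 0: [kD, kH, kW, iC, oC]
                                     {16, 8, 2, 3, 3},    // wFormat 1: [oC, iC, kD, kH, kW]
                                     {16, 2, 3, 3, 8} };  // wFormat 2: [oC, kD, kH, kW, iC]
    for (int wFormat = 0; wFormat < 3; ++wFormat) {
        int indWkD, indWiC, indWoC;
        if (0 == wFormat)      { indWkD = 0; indWiC = 3; indWoC = 4; }
        else if (1 == wFormat) { indWkD = 2; indWiC = 1; indWoC = 0; }
        else                   { indWkD = 1; indWiC = 4; indWoC = 0; }
        assert(wShape[wFormat][indWkD] == 2);    // kernel depth
        assert(wShape[wFormat][indWiC] == 8);    // input channels
        assert(wShape[wFormat][indWoC] == 16);   // output channels
    }
    return 0;
}
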
@@ -203,7 +221,6 @@ namespace sd {
oD = output.sizeAt(indIOioD);   // output depth
oH = output.sizeAt(indIOioD+1); // output height
oW = output.sizeAt(indIOioD+2); // output width
}

// static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) {
@@ -254,19 +271,41 @@ namespace sd {
// }
// }

static void conv2d(sd::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
static std::vector<Nd4jLong> expectWeightsShape(const int wFormat, const int kH, const int kW, const int iC, const int oC) {

    if(0 == wFormat)
        return std::vector<Nd4jLong>({kH, kW, iC, oC});

    if(1 == wFormat)
        return std::vector<Nd4jLong>({oC, iC, kH, kW});

    return std::vector<Nd4jLong>({oC, kH, kW, iC});
}

static std::vector<Nd4jLong> expectWeightsShape(const int wFormat, const int kD, const int kH, const int kW, const int iC, const int oC) {

    if(0 == wFormat)
        return std::vector<Nd4jLong>({kD, kH, kW, iC, oC});

    if(1 == wFormat)
        return std::vector<Nd4jLong>({oC, iC, kD, kH, kW});

    return std::vector<Nd4jLong>({oC, kD, kH, kW, iC});
}

static void conv2d(sd::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);

// static void conv2d(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);

// static void conv2dBP(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);

static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);

static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);

static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);

static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);

static void vol2col(sd::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
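A standalone companion sketch for the five-kernel-argument (conv3d) overload just declared. The local name expectWeightsShape3d and the sizes are illustrative only; in the library it is an overload of the same name, selected by arity.

#include <cassert>
#include <vector>

static std::vector<long long> expectWeightsShape3d(int wFormat, long long kD, long long kH, long long kW, long long iC, long long oC) {
    if (0 == wFormat) return {kD, kH, kW, iC, oC};
    if (1 == wFormat) return {oC, iC, kD, kH, kW};
    return {oC, kD, kH, kW, iC};
}

int main() {
    assert((expectWeightsShape3d(0, 2, 3, 3, 8, 16) == std::vector<long long>{2, 3, 3, 8, 16}));
    assert((expectWeightsShape3d(1, 2, 3, 3, 8, 16) == std::vector<long long>{16, 8, 2, 3, 3}));
    assert((expectWeightsShape3d(2, 2, 3, 3, 8, 16) == std::vector<long long>{16, 2, 3, 3, 8}));
    return 0;
}
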
@@ -258,10 +258,10 @@ namespace sd {
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, oC] always
// weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// bias [oC]
// output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@@ -278,7 +278,7 @@ namespace sd {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
@@ -291,6 +291,14 @@ namespace sd {
else
    input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC

std::vector<int> wAxes;
if(0 == wFormat)
    wAxes = {0, 1, 2};
else if(1 == wFormat)
    wAxes = {2, 3, 1};
else
    wAxes = {1, 2, 3};

NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext());
NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW}
NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext());
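The wAxes lists deserve a second look: tensorDot contracts col axes {3,4,5}, which are {kH, kW, iC}, so the weight-side axes must name the (kH, kW, iC) dimensions of each layout in exactly that order. A standalone dimension-name check (sketch only, not library code):

#include <cassert>
#include <string>
#include <vector>

int main() {
    const std::string dims[3][4] = { {"kH","kW","iC","oC"},      // wFormat 0
                                     {"oC","iC","kH","kW"},      // wFormat 1
                                     {"oC","kH","kW","iC"} };    // wFormat 2
    const std::vector<int> wAxes[3] = { {0,1,2}, {2,3,1}, {1,2,3} };
    for (int f = 0; f < 3; ++f) {
        assert(dims[f][wAxes[f][0]] == "kH");   // pairs with col axis 3
        assert(dims[f][wAxes[f][1]] == "kW");   // pairs with col axis 4
        assert(dims[f][wAxes[f][2]] == "iC");   // pairs with col axis 5
    }
    return 0;
}
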
@@ -298,7 +306,7 @@ namespace sd {
//----- calculation of output -----//
auto ctx = block.launchContext();
helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]
MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]

//----- assign outTemp to output -----//
if(isNCHW) {
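What the im2col + tensorDot pair computes is an ordinary convolution. The example below is a reference-only sketch on plain arrays (single batch, unit stride, no padding or dilation, weights laid out as wFormat 0, all sizes and values made up) showing the arithmetic that the column/GEMM formulation reproduces:

#include <cstdio>
#include <vector>

int main() {
    const int iC = 2, oC = 3, iH = 4, iW = 4, kH = 3, kW = 3;
    const int oH = iH - kH + 1, oW = iW - kW + 1;
    std::vector<float> in(iC*iH*iW, 1.0f), w(kH*kW*iC*oC, 0.5f), out(oC*oH*oW, 0.0f);
    for (int oc = 0; oc < oC; ++oc)
        for (int oh = 0; oh < oH; ++oh)
            for (int ow = 0; ow < oW; ++ow) {
                float s = 0.0f;
                for (int kh = 0; kh < kH; ++kh)
                    for (int kw = 0; kw < kW; ++kw)
                        for (int ic = 0; ic < iC; ++ic)
                            s += in[(ic*iH + oh + kh)*iW + ow + kw]
                               * w[((kh*kW + kw)*iC + ic)*oC + oc];   // weights [kH,kW,iC,oC]
                out[(oc*oH + oh)*oW + ow] = s;
            }
    std::printf("out[0] = %f (expected kH*kW*iC*0.5 = %f)\n", out[0], kH*kW*iC*0.5);
    return 0;
}

im2col merely materializes the (kh, kw, ic) gather windows as columns so that the three inner loops become one matrix multiplication.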
@@ -319,15 +327,15 @@ namespace sd {
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, oC] always
// weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// bias [oC]
// gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

// gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
// gradW [kH, kW, iC, oC] always
// gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// gradB [oC]

// kH filter(kernel) height
@@ -343,7 +351,7 @@ namespace sd {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
@@ -359,13 +367,28 @@ namespace sd {
    gradOaxesForDot = {0, 2, 3}; // bS, oH, oW
}

std::vector<int> wPermut, colPermut;

if(0 == wFormat) {
    wPermut = {2, 0, 1, 3};
    colPermut = {2, 3, 1, 0, 4, 5};
}
else if(1 == wFormat) {
    wPermut = {1, 2, 3, 0};
    colPermut = {1, 2, 3, 0, 4, 5};
}
else {
    wPermut = {3, 1, 2, 0};
    colPermut = {2, 3, 1, 0, 4, 5};
}

NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());

// ----- calculation of gradW ----- //
if(gradW) {
    auto ctx = block.launchContext();
    helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
    sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3}); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
    sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
}

// ----- calculation of gradB ----- //
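tensorDot's last argument permutes the output array so it can be written in the contraction's natural order: contracting columns over {0,4,5} against gradO leaves [iC, kH, kW, oC], so each wPermut must view gradW, whatever its wFormat layout, in exactly that order. A standalone name-shuffling check (sketch only; permute semantics assumed here: new axis i is old axis p[i]):

#include <cassert>
#include <string>

int main() {
    const std::string layout[3][4] = { {"kH","kW","iC","oC"},     // wFormat 0
                                       {"oC","iC","kH","kW"},     // wFormat 1
                                       {"oC","kH","kW","iC"} };   // wFormat 2
    const int wPermut[3][4] = { {2,0,1,3}, {1,2,3,0}, {3,1,2,0} };
    const std::string want[4] = {"iC","kH","kW","oC"};            // contraction-result order
    for (int f = 0; f < 3; ++f)
        for (int i = 0; i < 4; ++i)
            assert(layout[f][wPermut[f][i]] == want[i]);
    return 0;
}
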
@@ -379,9 +402,12 @@ namespace sd {
}

//----- calculation of gradI -----//
sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // [kH, kW, iC, oC]/[oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
// [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
// [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW]
// [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut);

helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]

if(!isNCHW) {
    delete input;
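The same reasoning applies one step later: colPermut views columns ([bS, iC, kH, kW, oH, oW]) in the order that weights x gradO naturally produces, which the three comment lines above spell out per wFormat. A companion sketch (standalone, same assumed permute semantics):

#include <cassert>
#include <string>

int main() {
    const std::string col[6] = {"bS", "iC", "kH", "kW", "oH", "oW"};
    const int colPermut[3][6] = { {2,3,1,0,4,5},      // wFormat 0 -> [kH, kW, iC, bS, oH, oW]
                                  {1,2,3,0,4,5},      // wFormat 1 -> [iC, kH, kW, bS, oH, oW]
                                  {2,3,1,0,4,5} };    // wFormat 2 -> [kH, kW, iC, bS, oH, oW]
    const std::string want[3][6] = { {"kH","kW","iC","bS","oH","oW"},
                                     {"iC","kH","kW","bS","oH","oW"},
                                     {"kH","kW","iC","bS","oH","oW"} };
    for (int f = 0; f < 3; ++f)
        for (int i = 0; i < 6; ++i)
            assert(col[colPermut[f][i]] == want[f][i]);
    return 0;
}
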
@@ -391,10 +417,10 @@ namespace sd {
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, mC] always
// weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// bias [oC] = iC*mC
// output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
@@ -411,23 +437,30 @@ namespace sd {
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier

std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW]
std::vector<std::vector<Nd4jLong>> modifOutput;
std::vector<std::vector<Nd4jLong>> modifOutput, modifWeights;
std::vector<Nd4jLong> outReShape;

if(!isNCHW) {
    outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
    modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
    input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
}
else {
    outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
    modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
}

if(0 == wFormat)
    modifWeights = {{2,0,1,3},{iC,kH*kW,mC}};
else if(1 == wFormat)
    modifWeights = {{1,2,3,0},{iC,kH*kW,mC}};
else
    modifWeights = {{3,1,2,0},{iC,kH*kW,mC}};

if(paddingMode == 1) // SAME
    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
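modifWeights packs a permute and a reshape: the permute brings every layout to [iC, kH, kW, mC], and the reshape then collapses the kernel dims so weights enter the batched product as [iC, kH*kW, mC]. A standalone name check of the permute half (sketch only, same assumed permute semantics as above):

#include <cassert>
#include <string>

int main() {
    const std::string layout[3][4] = { {"kH","kW","iC","mC"},     // wFormat 0
                                       {"mC","iC","kH","kW"},     // wFormat 1
                                       {"mC","kH","kW","iC"} };   // wFormat 2
    const int perm[3][4] = { {2,0,1,3}, {1,2,3,0}, {3,1,2,0} };
    const std::string want[4] = {"iC","kH","kW","mC"};
    for (int f = 0; f < 3; ++f)
        for (int i = 0; i < 4; ++i)
            assert(layout[f][perm[f][i]] == want[i]);   // the {iC, kH*kW, mC} reshape follows
    return 0;
}
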
@@ -435,7 +468,7 @@ namespace sd {
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);

helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]

if(bias)
    // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
@@ -447,14 +480,14 @@ namespace sd {
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, mC] always
// weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// bias [oC] = [iC*mC]
// gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
// gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
// gradW [kH, kW, iC, mC] always
// gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// gradB [oC]

// kH filter(kernel) height
@@ -470,19 +503,19 @@ namespace sd {
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier

std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW]
std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2;
std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2, modifWeights;
std::vector<Nd4jLong> gradOreShape;

if(!isNCHW) {
    gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
    modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
    modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
    input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
    gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
}
else {
    gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
@@ -490,6 +523,13 @@ namespace sd {
    modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
}

if(0 == wFormat)
    modifWeights = {{2,0,1,3},{iC,kH*kW,mC}};
else if(1 == wFormat)
    modifWeights = {{1,2,3,0},{iC,kH*kW,mC}};
else
    modifWeights = {{3,1,2,0},{iC,kH*kW,mC}};

if(paddingMode == 1) // SAME
    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@@ -499,7 +539,7 @@ namespace sd {
// ----- calculation of gradW and gradB ----- //

helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}}); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]
sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]

// ----- calculation of gradB ----- //
if(gradB) {
@@ -513,8 +553,8 @@ namespace sd {
}

//----- calculation of gradI -----//
sd::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]

if(!isNCHW) {
    delete input;
@@ -524,11 +564,11 @@ namespace sd {
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weightsDepth [kH, kW, iC, mC] always
// weightsPoint [1, 1, iC*mC, oC] always
// weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
// bias [oC], oC = iC*mC if weightsPoint=nullptr
// output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@@ -545,7 +585,7 @@ namespace sd {
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier

NDArray* outputDepth = output;
@@ -553,11 +593,11 @@ namespace sd {
    outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext());

// ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW);
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat);

// ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
if (weightsPoint) {
    ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW
    ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW
    delete outputDepth;
}
}
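The composition above, depthwise then 1x1 pointwise, can be mirrored on plain arrays. A conceptual sketch (single batch, unit stride, no padding, wFormat 0 layouts, all sizes and values made up):

#include <cstdio>
#include <vector>

int main() {
    const int iC = 2, mC = 2, oC = 3, iH = 3, iW = 3, kH = 2, kW = 2;
    const int oH = iH - kH + 1, oW = iW - kW + 1;
    std::vector<float> in(iC*iH*iW, 1.0f), wD(kH*kW*iC*mC, 1.0f), wP(iC*mC*oC, 1.0f);
    std::vector<float> mid(iC*mC*oH*oW, 0.0f), out(oC*oH*oW, 0.0f);

    // depthwise: channel ic convolved with its own mC kernels -> channel ic*mC + m
    for (int ic = 0; ic < iC; ++ic)
        for (int m = 0; m < mC; ++m)
            for (int oh = 0; oh < oH; ++oh)
                for (int ow = 0; ow < oW; ++ow) {
                    float s = 0.0f;
                    for (int kh = 0; kh < kH; ++kh)
                        for (int kw = 0; kw < kW; ++kw)
                            s += in[(ic*iH + oh + kh)*iW + ow + kw]
                               * wD[((kh*kW + kw)*iC + ic)*mC + m];   // [kH,kW,iC,mC]
                    mid[((ic*mC + m)*oH + oh)*oW + ow] = s;
                }

    // pointwise: 1x1 conv mixing the iC*mC intermediate channels into oC
    for (int oc = 0; oc < oC; ++oc)
        for (int p = 0; p < oH*oW; ++p) {
            float s = 0.0f;
            for (int c = 0; c < iC*mC; ++c)
                s += mid[c*oH*oW + p] * wP[c*oC + oc];                // [1,1,iC*mC,oC]
            out[oc*oH*oW + p] = s;
        }
    std::printf("out[0] = %f\n", out[0]);   // = iC*mC*kH*kW = 16 with all-ones data
    return 0;
}
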
@@ -1772,20 +1812,20 @@ namespace sd {
void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}
void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}
void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}
void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}
void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}
void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
    BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
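Not the libnd4j macro itself, but a hand-rolled sketch of the same idea behind BUILD_SINGLE_SELECTOR_TWICE: pick a typed template instantiation from a runtime dtype tag (the enum, function names and supported types below are illustrative assumptions):

#include <cstdio>
#include <stdexcept>

enum class DType { FLOAT32, DOUBLE };

template <typename X, typename Y>
static void conv2dTyped() { std::printf("instantiated for %zu-byte type\n", sizeof(X)); }

static void conv2d(DType dtype) {
    switch (dtype) {                               // runtime tag -> compile-time types
        case DType::FLOAT32: conv2dTyped<float,  float >(); break;
        case DType::DOUBLE:  conv2dTyped<double, double>(); break;
        default: throw std::runtime_error("unsupported dtype");
    }
}

int main() { conv2d(DType::FLOAT32); conv2d(DType::DOUBLE); return 0; }

The "TWICE" variant simply expands the same type for both template slots, which is why the new wFormat argument only had to be threaded through the macro's argument tuple.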
@@ -217,10 +217,10 @@ void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, ND
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, oC] always
// weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// bias [oC]
// output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@@ -237,7 +237,7 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
@@ -248,6 +248,14 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr
else
    input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC

std::vector<int> wAxes;
if(0 == wFormat)
    wAxes = {0, 1, 2};
else if(1 == wFormat)
    wAxes = {2, 3, 1};
else
    wAxes = {1, 2, 3};

NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext());
NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW}
NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext());
@@ -255,7 +263,7 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr
//----- calculation of output -----//
auto ctx = block.launchContext();
helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]
MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]

//----- assign outTemp to output -----//
if(isNCHW) {
@@ -275,16 +283,16 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr
}

//////////////////////////////////////////////////////////////////////////
void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, mC] always
// weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// bias [oC] = iC*mC
// output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
@@ -301,23 +309,30 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier

std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW]
std::vector<std::vector<Nd4jLong>> modifOutput;
std::vector<std::vector<Nd4jLong>> modifOutput, modifWeights;
std::vector<Nd4jLong> outReShape;

if(!isNCHW) {
    outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
    modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
    input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
}
else {
    outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
    modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
}

if(0 == wFormat)
    modifWeights = {{2,0,1,3},{iC,kH*kW,mC}};
else if(1 == wFormat)
    modifWeights = {{1,2,3,0},{iC,kH*kW,mC}};
else
    modifWeights = {{3,1,2,0},{iC,kH*kW,mC}};

if(paddingMode == 1) // SAME
    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@@ -325,7 +340,7 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);

helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]

if(bias)
    // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
@@ -336,17 +351,17 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co
}

//////////////////////////////////////////////////////////////////////////
void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weightsDepth [kH, kW, iC, mC] always
// weightsPoint [1, 1, iC*mC, oC] always
// weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
// bias [oC], oC = iC*mC if weightsPoint=nullptr
// output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@@ -363,7 +378,7 @@ static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDAr
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier

NDArray* outputDepth = output;
@@ -371,18 +386,18 @@ static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDAr
    outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext());

// ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW);
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat);

// ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
if (weightsPoint) {
    ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW
    ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW
    delete outputDepth;
}
}

//////////////////////////////////////////////////////////////////////////
void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}

//////////////////////////////////////////////////////////////////////////
@@ -1176,15 +1191,15 @@ void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& inp
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, oC] always
// weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// bias [oC]
// gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

// gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
// gradW [kH, kW, iC, oC] always
// gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// gradB [oC]

// kH filter(kernel) height
@ -1200,7 +1215,7 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA
|
|||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||
|
||||
@@ -1214,13 +1229,27 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA
    gradOaxesForDot = {0, 2, 3}; // bS, oH, oW
}

std::vector<int> wPermut, colPermut;
if(0 == wFormat) {
    wPermut = {2, 0, 1, 3};
    colPermut = {2, 3, 1, 0, 4, 5};
}
else if(1 == wFormat) {
    wPermut = {1, 2, 3, 0};
    colPermut = {1, 2, 3, 0, 4, 5};
}
else {
    wPermut = {3, 1, 2, 0};
    colPermut = {2, 3, 1, 0, 4, 5};
}

NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());

// ----- calculation of gradW ----- //
if(gradW) {
    auto ctx = block.launchContext();
    helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
    sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3}); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
    sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
}

// ----- calculation of gradB ----- //
|
@ -1234,7 +1263,10 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA
|
|||
}
|
||||
|
||||
//----- calculation of gradI -----//
|
||||
sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
|
||||
// [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
|
||||
// [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW]
|
||||
// [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
|
||||
sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
|
||||
|
||||
helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
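All three branches above expose the same contraction result: whatever layout the caller stores gradW in, wPermut is the permutation under which that layout reads as the canonical [iC, kH, kW, oC] product of the tensorDot. A minimal standalone sketch of that dispatch (illustrative name, not library code):

    #include <stdexcept>
    #include <vector>

    // Permutation under which the stored gradW layout reads as the canonical
    // [iC, kH, kW, oC] result of the im2col x gradO contraction above.
    std::vector<int> gradWPermutSketch(int wFormat) {
        switch (wFormat) {
            case 0:  return {2, 0, 1, 3};  // stored as [kH, kW, iC, oC]
            case 1:  return {1, 2, 3, 0};  // stored as [oC, iC, kH, kW]
            case 2:  return {3, 1, 2, 0};  // stored as [oC, kH, kW, iC]
            default: throw std::invalid_argument("wFormat must be 0, 1 or 2");
        }
    }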
@@ -1245,20 +1277,20 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA
     }

//////////////////////////////////////////////////////////////////////////
-void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
-    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
+void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
+    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
-static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
+static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {

     // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    // weights [kH, kW, iC, mC] always
+    // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
     // bias    [oC] = [iC*mC]
     // gradO   [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
     // gradI   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-    // gradW   [kH, kW, iC, mC] always
+    // gradW   [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
     // gradB   [oC]

     // kH      filter(kernel) height

@@ -1274,11 +1306,11 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con

     int bS, iC, iH, iW, mC, oC, oH, oW;                    // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
     mC = weights->sizeAt(indWmC);  // channels multiplier

     std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}};  // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW]
-    std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2;
+    std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2, modifWeights;
    std::vector<Nd4jLong> gradOreShape;

     if(!isNCHW) {

@@ -1294,6 +1326,13 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
         modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}};  // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
     }

+    if(0 == wFormat)
+        modifWeights = {{2,0,1,3},{iC,kH*kW,mC}};
+    else if(1 == wFormat)
+        modifWeights = {{1,2,3,0},{iC,kH*kW,mC}};
+    else
+        modifWeights = {{3,1,2,0},{iC,kH*kW,mC}};
+
     if(paddingMode == 1)  // SAME
         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
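Each modifWeights entry is a {permutation, target shape} pair that tensorDot applies to the weights before the matrix multiply, so every wFormat collapses to the same [iC, kH*kW, mC] view. A toy shape check, assuming nothing beyond the standard library:

    #include <cassert>
    #include <vector>

    // Applies a permutation to a shape vector: axis i of the result is axis
    // perm[i] of the input, mirroring NDArray::permute semantics.
    std::vector<long> permuteShape(const std::vector<long>& shape, const std::vector<int>& perm) {
        std::vector<long> out(perm.size());
        for (size_t i = 0; i < perm.size(); ++i)
            out[i] = shape[perm[i]];
        return out;
    }

    int main() {
        // wFormat == 0: [kH, kW, iC, mC] --{2,0,1,3}--> [iC, kH, kW, mC], then reshape to [iC, kH*kW, mC]
        assert((permuteShape({3, 3, 8, 2}, {2, 0, 1, 3}) == std::vector<long>{8, 3, 3, 2}));
        // wFormat == 1: [mC, iC, kH, kW] --{1,2,3,0}--> [iC, kH, kW, mC]
        assert((permuteShape({2, 8, 3, 3}, {1, 2, 3, 0}) == std::vector<long>{8, 3, 3, 2}));
        return 0;
    }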
@@ -1303,7 +1342,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
     // ----- calculation of gradW and gradB ----- //

     helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
-    sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}});  // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]
+    sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights);              // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]

     // ----- calculation of gradB ----- //
     if(gradB) {

@@ -1316,7 +1355,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
     }

     //----- calculation of gradI -----//
-    sd::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns);  // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
+    sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns);               // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
     helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);                     // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]

     if(!isNCHW) {

@@ -1326,8 +1365,8 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
     }

//////////////////////////////////////////////////////////////////////////
-void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
-    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
+void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
+    BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
}

@@ -102,7 +102,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) {

     int bS, iC, iH, iW, oC, oH, oW;                        // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});

@@ -54,7 +54,7 @@ PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) {

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
     REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());

@@ -108,7 +108,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) {

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});

@@ -34,22 +34,25 @@ static void conv2dCUDNN(const LaunchContext* context,
                         const int sH, const int sW,
                         const int pH, const int pW,
                         const int dH, const int dW,
-                        const int paddingMode, const bool isNCHW) {
+                        const int paddingMode, const bool isNCHW, const int wFormat) {

+    // cudnn supports only two formats for weights: {oC,iC,kH,kW} and {oC,kH,kW,iC}

     int bS, iC, iH, iW, oC, oH, oW;                        // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
     if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err);

-    cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+    cudnnTensorFormat_t format  = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+    cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC);

     // input descriptor
     cudnnTensorDescriptor_t x;
     cudnnCreateTensorDescriptor(&x);
-    if(input->ews() == 1)
+    if(input->ews() == 1 && input->ordering() == 'c')
         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
     else
         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
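The formatW line encodes the whole cuDNN weights story of this diff: cuDNN filter descriptors only understand [oC, iC, kH, kW] (CUDNN_TENSOR_NCHW) and [oC, kH, kW, iC] (CUDNN_TENSOR_NHWC), so wFormat 1 and 2 map directly, while wFormat 0 weights are permuted beforehand and the descriptor then follows the data format. A hedged restatement (the helper name is illustrative, not library code):

    // Sketch of the mapping above; requires cudnn.h for the enum.
    cudnnTensorFormat_t filterFormatSketch(int wFormat, bool isNCHW) {
        if (wFormat == 1) return CUDNN_TENSOR_NCHW;  // weights already [oC, iC, kH, kW]
        if (wFormat == 2) return CUDNN_TENSOR_NHWC;  // weights already [oC, kH, kW, iC]
        return isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;  // wFormat 0: after the callers' permute
    }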
@@ -58,13 +61,13 @@ static void conv2dCUDNN(const LaunchContext* context,
     // weights descriptor
     cudnnFilterDescriptor_t w;
     cudnnCreateFilterDescriptor(&w);
-    err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW);
+    err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), formatW, oC, iC, kH, kW);
     if(err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err);

     // output descriptor
     cudnnTensorDescriptor_t z;
     cudnnCreateTensorDescriptor(&z);
-    if(output->ews() == 1)
+    if(output->ews() == 1 && output->ordering() == 'c')
         err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW);
     else
         err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1));

@@ -104,10 +107,10 @@ static void conv2dCUDNN(const LaunchContext* context,
     // add bias if it is present
     if (bias != nullptr) {

         cudnnTensorDescriptor_t b;
         cudnnCreateTensorDescriptor(&b);
-        err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf());
+        // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf());
+        err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1);
         if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err);
         err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer());
         if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", err);
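The bias fix above replaces a layout-dependent descriptor with a fixed [1, oC, 1, 1] NCHW one; a plausible reading (an assumption, the commit gives no rationale) is that cudnnAddTensor broadcasts such a per-channel tensor correctly over both NCHW and NHWC outputs, whereas the old shape tied the bias description to the data format. A minimal sketch of the new call in isolation (error handling elided; oC and the float type are stand-ins for the surrounding scope):

    cudnnTensorDescriptor_t b;
    cudnnCreateTensorDescriptor(&b);
    // one value per output channel, broadcast over batch and both spatial dims
    cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, /*n*/1, /*c*/oC, /*h*/1, /*w*/1);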
@@ -131,22 +134,23 @@ static void conv2dBpCUDNN(const LaunchContext* context,
                           const int sH, const int sW,
                           const int pH, const int pW,
                           const int dH, const int dW,
-                          const int paddingMode, const bool isNCHW) {
+                          const int paddingMode, const bool isNCHW, const int wFormat) {

     int bS, iC, iH, iW, oC, oH, oW;                        // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
     if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", err);

-    cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+    cudnnTensorFormat_t format  = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+    cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC);

     // input descriptor
     cudnnTensorDescriptor_t x;
     cudnnCreateTensorDescriptor(&x);
-    if(input->ews() == 1)
+    if(input->ews() == 1 && input->ordering() == 'c')
         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
     else
         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));

@@ -155,7 +159,7 @@ static void conv2dBpCUDNN(const LaunchContext* context,
     // gradO descriptor
     cudnnTensorDescriptor_t dz;
     cudnnCreateTensorDescriptor(&dz);
-    if(gradO->ews() == 1)
+    if(gradO->ews() == 1 && gradO->ordering() == 'c')
         err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
     else
         err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));

@@ -164,7 +168,7 @@ static void conv2dBpCUDNN(const LaunchContext* context,
     // gradI descriptor
     cudnnTensorDescriptor_t dx;
     cudnnCreateTensorDescriptor(&dx);
-    if(gradI->ews() == 1)
+    if(gradI->ews() == 1 && gradI->ordering() == 'c')
         err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW);
     else
         err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1));

@@ -173,7 +177,7 @@ static void conv2dBpCUDNN(const LaunchContext* context,
     // gradW descriptor
     cudnnFilterDescriptor_t dw;
     cudnnCreateFilterDescriptor(&dw);
-    err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW);
+    err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, oC, iC, kH, kW);
     if(err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err);

     // description of convolution

@@ -220,7 +224,8 @@ static void conv2dBpCUDNN(const LaunchContext* context,
     if(gradB != nullptr) {
         cudnnTensorDescriptor_t db;
         cudnnCreateTensorDescriptor(&db);
-        err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf());
+        // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf());
+        err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1);
         if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err);

         err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer());

@@ -251,7 +256,7 @@ static void conv2dBpCUDNN(const LaunchContext* context,
PLATFORM_IMPL(conv2d, ENGINE_CUDA) {

     auto input   = INPUT_VARIABLE(0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weights = INPUT_VARIABLE(1);  // [kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1);  // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;  // [oC]

     auto output  = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

@@ -263,7 +268,8 @@ PLATFORM_IMPL(conv2d, ENGINE_CUDA) {
     int dH = INT_ARG(6);           // dilations height
     int dW = INT_ARG(7);           // dilations width
     int paddingMode = INT_ARG(8);  // 0-VALID, 1-SAME
-    bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;     // INT_ARG(9): 0-NCHW, 1-NHWC
+    bool isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;    // INT_ARG(9): 0-NCHW, 1-NHWC
+    int  wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;   // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));  // filter(kernel) height
     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));  // filter(kernel) width

@@ -273,31 +279,35 @@ PLATFORM_IMPL(conv2d, ENGINE_CUDA) {
     int bS, iC, iH, iW, oC, oH, oW;                        // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if (bias) {
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
         REQUIRE_TRUE((bias->rankOf() == 1 && bias->strideAt(0) == 1) || (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && bias->strideAt(1) == 1) || (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && bias->strideAt(0) == 1), 0, "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !");
     }

-    NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext());  // cudnn supports only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC}
-    newWeights->assign(weights->permute({3,2,0,1}));  // permute weights (kH, kW, iC, oC --> oC, iC, kH, kW)
+    NDArray* newWeights = weights;  // cudnn supports only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC}
+    if(0 == wFormat) {
+        newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector<Nd4jLong>({oC, iC, kH, kW}) : std::vector<Nd4jLong>({oC, kH, kW, iC}), weights->dataType(), weights->getContext());
+        newWeights->assign(weights->permute(isNCHW ? std::vector<int>({3,2,0,1}) : std::vector<int>({3,0,1,2})));  // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC)
+    }

     NDArray* newInput = input;
     NDArray* newGradI = nullptr;
     if(paddingMode == 1)  // in same paddingMode cudnn doesn't support asymmetric left/right top/bottom paddings
         checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW);

-    conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW);
+    conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW, wFormat);

     if(newInput != input)
         delete newInput;

-    delete newWeights;
+    if(0 == wFormat)
+        delete newWeights;

     return Status::OK();
}
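Because wFormat 0 has no native cuDNN filter layout, the forward op above materializes a permuted copy; the two permutations used, {3,2,0,1} and {3,0,1,2}, can be sanity-checked on plain shape vectors (a standalone toy, assuming nothing from the library):

    #include <cassert>
    #include <vector>

    int main() {
        std::vector<long> w = {3, 3, 8, 16};  // default wFormat 0 layout: [kH, kW, iC, oC]
        const int toNCHW[4] = {3, 2, 0, 1};   // -> [oC, iC, kH, kW]
        const int toNHWC[4] = {3, 0, 1, 2};   // -> [oC, kH, kW, iC]
        std::vector<long> nchw(4), nhwc(4);
        for (int i = 0; i < 4; ++i) { nchw[i] = w[toNCHW[i]]; nhwc[i] = w[toNHWC[i]]; }
        assert((nchw == std::vector<long>{16, 8, 3, 3}));
        assert((nhwc == std::vector<long>{16, 3, 3, 8}));
        return 0;
    }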
@@ -322,12 +332,12 @@ PLATFORM_CHECK(conv2d, ENGINE_CUDA) {
PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {

     auto input   = INPUT_VARIABLE(0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weights = INPUT_VARIABLE(1);  // [kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1);  // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;            // [oC]
     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

     auto gradI = OUTPUT_VARIABLE(0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-    auto gradW = OUTPUT_VARIABLE(1);  // [kH, kW, iC, oC] always
+    auto gradW = OUTPUT_VARIABLE(1);  // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;  // [oC]

     int kH = INT_ARG(0);  // filter(kernel) height

@@ -340,6 +350,7 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {
     int dW = INT_ARG(7);           // dilations width
     int paddingMode = INT_ARG(8);  // 0-VALID, 1-SAME
     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;   // INT_ARG(9): 0-NCHW, 1-NHWC
+    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;  // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

     REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
     REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());

@@ -347,7 +358,7 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {
     int bS, iC, iH, iW, oC, oH, oW;                        // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     int trueoH, trueoW;  // true output height, width
     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);

@@ -355,26 +366,30 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {
     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
-    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if(bias)
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

-    NDArray* newGradW   = new NDArray(gradW->ordering(),   {oC, iC, kH, kW}, gradW->dataType(),   gradW->getContext());  // cudnn supports only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC}
-    NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext());
-    newWeights->assign(weights->permute({3,2,0,1}));  // permute weights (kH, kW, iC, oC --> oC, iC, kH, kW)
+    NDArray *newWeights = weights, *newGradW = gradW;  // cudnn supports only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC}
+    if(0 == wFormat) {
+        newGradW   = new NDArray(gradW->ordering(),   isNCHW ? std::vector<Nd4jLong>({oC, iC, kH, kW}) : std::vector<Nd4jLong>({oC, kH, kW, iC}), gradW->dataType(),   gradW->getContext());
+        newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector<Nd4jLong>({oC, iC, kH, kW}) : std::vector<Nd4jLong>({oC, kH, kW, iC}), weights->dataType(), weights->getContext());
+        newWeights->assign(weights->permute(isNCHW ? std::vector<int>({3,2,0,1}) : std::vector<int>({3,0,1,2})));  // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC)
+    }

     NDArray* newInput = input;
     NDArray* newGradI = gradI;
     if(paddingMode == 1)  // in same paddingMode cudnn doesn't support asymmetric left/right top/bottom paddings
         checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW);

-    conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW);
+    conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW,wFormat);

-    newGradW->permutei({2,3,1,0});  // [oC, iC, kH, kW] -> [kH, kW, iC, oC]
-    gradW->assign(newGradW);
+    if(0 == wFormat) {
+        newGradW->permutei(isNCHW ? std::vector<int>({2,3,1,0}) : std::vector<int>({1,2,3,0}));  // (oC, iC, kH, kW --> kH, kW, iC, oC) or (oC, kH, kW, iC --> kH, kW, iC, oC)
+        gradW->assign(newGradW);
+    }

     if(newInput != input) {

@@ -387,8 +402,10 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {
         delete newGradI;
     }

-    delete newWeights;
-    delete newGradW;
+    if(0 == wFormat) {
+        delete newWeights;
+        delete newGradW;
+    }

     return Status::OK();
}
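The permute-back pair in the backward op is the inverse of the permute-in pair: gradW is computed in the cuDNN layout and then written back to the caller's wFormat-0 layout. A small helper makes the relationship explicit (a sketch, not library code):

    #include <vector>

    // Inverse of a permutation: if axis i of the permuted array came from axis
    // p[i], the inverse sends it back, so inv[p[i]] == i.
    std::vector<int> invertPermutation(const std::vector<int>& p) {
        std::vector<int> inv(p.size());
        for (int i = 0; i < static_cast<int>(p.size()); ++i)
            inv[p[i]] = i;
        return inv;
    }
    // invertPermutation({3,2,0,1}) == {2,3,1,0} and invertPermutation({3,0,1,2}) == {1,2,3,0},
    // exactly the pairs used by the NCHW and NHWC branches above.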
@@ -34,13 +34,15 @@ static void conv3dCUDNN(const LaunchContext* context,
                         const int sD, const int sH, const int sW,
                         const int pD, const int pH, const int pW,
                         const int dD, const int dH, const int dW,
-                        const int paddingMode, const bool isNCDHW) {
+                        const int paddingMode, const bool isNCDHW, const int wFormat) {

+    // cudnn supports only one format for weights: {oC,iC,kD,kH,kW}

     const int numDims = 5;

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());

@@ -53,7 +55,7 @@ static void conv3dCUDNN(const LaunchContext* context,
     const std::vector<int> xShape = {bS, iC, iD, iH, iW};
     const std::vector<int> zShape = {bS, oC, oD, oH, oW};
     const std::vector<int> wShape = {oC, iC, kD, kH, kW};
-    const std::vector<int> bShape = {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)};
+    const std::vector<int> bShape = {1, oC, 1, 1, 1};  // {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)};

     const std::vector<int> xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
     const std::vector<int> zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)};

@@ -120,7 +122,7 @@ static void conv3dCUDNN(const LaunchContext* context,
     cudnnTensorDescriptor_t b;
     cudnnCreateTensorDescriptor(&b);
-    err = cudnnSetTensorNdDescriptorEx(b, format, cudnnDataType(bias->dataType()), numDims, bShape.data());
+    err = cudnnSetTensorNdDescriptorEx(b, /*format*/CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), numDims, bShape.data());
     if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err);
     err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer());
     if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", err);

@@ -144,13 +146,15 @@ static void conv3dBpCUDNN(const LaunchContext* context,
                           const int sD, const int sH, const int sW,
                           const int pD, const int pH, const int pW,
                           const int dD, const int dH, const int dW,
-                          const int paddingMode, const bool isNCDHW) {
+                          const int paddingMode, const bool isNCDHW, const int wFormat) {

+    // cudnn supports only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} for weights/gradW

     const int numDims = 5;

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());

@@ -170,6 +174,7 @@ static void conv3dBpCUDNN(const LaunchContext* context,
     const std::vector<int> dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)};

     cudnnTensorFormat_t format  = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+    cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC);

     // input descriptor
     cudnnTensorDescriptor_t x;

@@ -201,7 +206,7 @@ static void conv3dBpCUDNN(const LaunchContext* context,
     // gradW descriptor
     cudnnFilterDescriptor_t dw;
     cudnnCreateFilterDescriptor(&dw);
-    err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, numDims, wShape.data());
+    err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, numDims, wShape.data());
     if(err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err);

     // description of convolution

@@ -280,7 +285,7 @@ static void conv3dBpCUDNN(const LaunchContext* context,
PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) {

     auto input   = INPUT_VARIABLE(0);  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;  // [oC]
     auto output  = OUTPUT_VARIABLE(0);  // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

@@ -301,34 +306,39 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) {
     int dW = INT_ARG(11);           // dilations width
     int paddingMode = INT_ARG(12);  // 0-VALID, 1-SAME
     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;  // INT_ARG(13): 1-NDHWC, 0-NCDHW
+    int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;   // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]

     REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode);

-    std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if (bias)
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

-    NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext());  // cudnn supports only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC}
-    newWeights->assign(weights->permute({4,3,0,1,2}));  // permute weights (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW)
+    NDArray* newWeights = weights;  // cudnn supports only one format {oC,iC,kD,kH,kW}
+    if(1 != wFormat) {
+        newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext());
+        newWeights->assign(weights->permute(0 == wFormat ? std::vector<int>({4,3,0,1,2}) : std::vector<int>({0,4,1,2,3})));  // (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) or (oC, kD, kH, kW, iC --> oC, iC, kD, kH, kW)
+    }

     NDArray* newInput = input;
     NDArray* newGradI = nullptr;
     if(paddingMode == 1)  // in same paddingMode cudnn doesn't support asymmetric left/right top/bottom paddings
         checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);

-    conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW);
+    conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW, wFormat);

     if(newInput != input)
         delete newInput;

-    delete newWeights;
+    if(1 != wFormat)
+        delete newWeights;

     return Status::OK();
}
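Unlike the 2d path, the 3d forward pass leaves only wFormat 1 untouched, since this cuDNN route accepts a single {oC, iC, kD, kH, kW} filter layout; wFormat 0 and 2 both get copied through a permute. The expected per-format input shapes, written out as a hypothetical stand-in for ConvolutionUtils::expectWeightsShape (the real helper's body is not shown in this diff):

    #include <vector>
    using Nd4jLong = long long;  // assumption: stands in for the library typedef

    std::vector<Nd4jLong> expectWeightsShape3dSketch(int wFormat, Nd4jLong kD, Nd4jLong kH,
                                                     Nd4jLong kW, Nd4jLong iC, Nd4jLong oC) {
        if (wFormat == 0) return {kD, kH, kW, iC, oC};
        if (wFormat == 1) return {oC, iC, kD, kH, kW};
        return {oC, kD, kH, kW, iC};  // wFormat == 2
    }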
@@ -337,7 +347,7 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) {
PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) {

     auto input   = INPUT_VARIABLE(0);  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;  // [oC]

     int paddingMode = INT_ARG(12);  // 0-VALID, 1-SAME

@@ -353,12 +363,12 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) {
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {

     auto input   = INPUT_VARIABLE(0);  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;            // [oC]
     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);  // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

     auto gradI = OUTPUT_VARIABLE(0);  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
-    auto gradW = OUTPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC] always
+    auto gradW = OUTPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;  // [oC]

     REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());

@@ -379,10 +389,11 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
     int dW = INT_ARG(11);           // dilations width
     int paddingMode = INT_ARG(12);  // 0-VALID, 1-SAME
     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;  // INT_ARG(13): 1-NDHWC, 0-NCDHW
+    int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;   // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     int trueoD, trueoH, trueoW;  // true output depth/height/width
     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

@@ -390,7 +401,7 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
     REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");

     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
-    std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV3D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
     REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
     if(bias)

@@ -398,20 +409,25 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
     ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode);

-    NDArray* newGradW   = new NDArray(gradW->ordering(),   {oC, iC, kD, kH, kW}, gradW->dataType(),   gradW->getContext());  // cudnn supports only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC}
-    NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext());
-    newWeights->assign(weights->permute({4,3,0,1,2}));  // permute weights (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW)
+    NDArray *newWeights = weights, *newGradW = gradW;  // cudnn supports only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC}
+    if(0 == wFormat) {
+        newGradW   = new NDArray(gradW->ordering(),   isNCDHW ? std::vector<Nd4jLong>({oC, iC, kD, kH, kW}) : std::vector<Nd4jLong>({oC, kD, kH, kW, iC}), gradW->dataType(),   gradW->getContext());
+        newWeights = new NDArray(weights->ordering(), isNCDHW ? std::vector<Nd4jLong>({oC, iC, kD, kH, kW}) : std::vector<Nd4jLong>({oC, kD, kH, kW, iC}), weights->dataType(), weights->getContext());
+        newWeights->assign(weights->permute(isNCDHW ? std::vector<int>({4,3,0,1,2}) : std::vector<int>({4,0,1,2,3})));  // (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) or (kD, kH, kW, iC, oC --> oC, kD, kH, kW, iC)
+    }

     NDArray* newInput = input;
     NDArray* newGradI = gradI;
     if(paddingMode == 1)  // in same paddingMode cudnn doesn't support asymmetric left/right top/bottom paddings
         checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);

-    conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW);
+    conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW,wFormat);

-    newGradW->permutei({2,3,4,1,0});  // [oC, iC, kD, kH, kW] -> [kD, kH, kW, iC, oC]
-    gradW->assign(newGradW);
+    if(0 == wFormat) {
+        newGradW->permutei(isNCDHW ? std::vector<int>({2,3,4,1,0}) : std::vector<int>({1,2,3,4,0}));  // (oC, iC, kD, kH, kW --> kD, kH, kW, iC, oC) or (oC, kD, kH, kW, iC --> kD, kH, kW, iC, oC)
+        gradW->assign(newGradW);
+    }

     if(newInput != input) {

@@ -424,8 +440,10 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
         delete newGradI;
     }

-    delete newWeights;
-    delete newGradW;
+    if(0 == wFormat) {
+        delete newWeights;
+        delete newGradW;
+    }

     return Status::OK();
}

@@ -433,7 +451,7 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) {

     auto input   = INPUT_VARIABLE(0);  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC] always
+    auto weights = INPUT_VARIABLE(1);  // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;            // [oC]
     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);  // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

@@ -124,7 +124,7 @@ void pooling2dCUDNN(const LaunchContext* context,

     int bS, iC, iH, iW, oC, oH, oW;                        // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());

@@ -135,7 +135,7 @@ void pooling2dCUDNN(const LaunchContext* context,
     // input descriptor
     cudnnTensorDescriptor_t x;
     cudnnCreateTensorDescriptor(&x);
-    if(input->ews() == 1)
+    if(input->ews() == 1 && input->ordering() == 'c')
         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
     else
         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
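The recurring tightening of `ews() == 1` to `ews() == 1 && ordering() == 'c'` is the other cross-cutting fix in this commit: a unit element-wise stride only says the buffer is dense, not that it is row-major, while the fixed-format Set*Descriptor calls implicitly assume C order, so 'f'-ordered arrays now take the explicit-strides branch. As a predicate (a sketch; an NDArray-like interface is assumed):

    // True when the cheap fixed-format cuDNN descriptor is safe: the buffer is
    // dense (unit element-wise stride) AND laid out in C ('c') order. A dense
    // 'f'-ordered array also has ews() == 1, so the ordering check is required.
    template <typename NdArrayLike>
    bool canUseFixedFormatDescriptor(const NdArrayLike& a) {
        return a.ews() == 1 && a.ordering() == 'c';
    }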
@@ -144,7 +144,7 @@ void pooling2dCUDNN(const LaunchContext* context,
     // output descriptor
     cudnnTensorDescriptor_t z;
     cudnnCreateTensorDescriptor(&z);
-    if(output->ews() == 1)
+    if(output->ews() == 1 && output->ordering() == 'c')
         err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW);
     else
         err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1));

@@ -187,7 +187,7 @@ void pooling2dBpCUDNN(const LaunchContext* context,

     int bS, iC, iH, iW, oC, oH, oW;                        // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());

@@ -198,7 +198,7 @@ void pooling2dBpCUDNN(const LaunchContext* context,
     // input and gradI descriptor
     cudnnTensorDescriptor_t x;
     cudnnCreateTensorDescriptor(&x);
-    if(input->ews() == 1)
+    if(input->ews() == 1 && input->ordering() == 'c')
         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
     else
         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));

@@ -207,7 +207,7 @@ void pooling2dBpCUDNN(const LaunchContext* context,
     // gradO descriptor
     cudnnTensorDescriptor_t dz;
     cudnnCreateTensorDescriptor(&dz);
-    if(gradO->ews() == 1)
+    if(gradO->ews() == 1 && gradO->ordering() == 'c')
         err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
     else
         err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));

@@ -255,7 +255,7 @@ void pooling3dCUDNN(const LaunchContext* context,

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     const int pSizes[] = {pD, pH, pW};
     const int sSizes[] = {sD, sH, sW};

@@ -272,7 +272,7 @@ void pooling3dCUDNN(const LaunchContext* context,
     // input descriptor
     cudnnTensorDescriptor_t x;
     cudnnCreateTensorDescriptor(&x);
-    if(input->ews() == 1)
+    if(input->ews() == 1 && input->ordering() == 'c')
         err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
     else
         err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);

@@ -281,7 +281,7 @@ void pooling3dCUDNN(const LaunchContext* context,
     // output descriptor
     cudnnTensorDescriptor_t z;
     cudnnCreateTensorDescriptor(&z);
-    if(output->ews() == 1)
+    if(output->ews() == 1 && output->ordering() == 'c')
         err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape);
     else
         err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides);

@@ -330,7 +330,7 @@ void pooling3dBpCUDNN(const LaunchContext* context,

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     const int pSizes[] = {pD, pH, pW};
     const int sSizes[] = {sD, sH, sW};

@@ -347,7 +347,7 @@ void pooling3dBpCUDNN(const LaunchContext* context,
     // input and gradI descriptor
     cudnnTensorDescriptor_t x;
     cudnnCreateTensorDescriptor(&x);
-    if(input->ews() == 1)
+    if(input->ews() == 1 && input->ordering() == 'c')
         err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
     else
         err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);

@@ -356,7 +356,7 @@ void pooling3dBpCUDNN(const LaunchContext* context,
     // gradO descriptor
     cudnnTensorDescriptor_t dz;
     cudnnCreateTensorDescriptor(&dz);
-    if(gradO->ews() == 1)
+    if(gradO->ews() == 1 && gradO->ordering() == 'c')
         err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape);
     else
         err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides);

@@ -39,14 +39,14 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context,
     // cudnn supports only the following case: mC = 1, oC = iC (groupCount == iC)

     // input   [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc
-    // weights [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute
+    // weights [iC, mC, kH, kW]
     // bias    [oC], may be nullptr
     // output  [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
     // oC = iC*mC

     int bS, iC, iH, iW, mC, oC, oH, oW;                    // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
     mC = weights->sizeAt(1);

     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
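The comment above refers to cuDNN's grouped-convolution trick: a depthwise convolution with mC == 1 is exactly a grouped convolution with one group per input channel. A hedged sketch of the relevant calls (standard cuDNN API; pH, pW, sH, sW, dH, dW and iC stand in for the enclosing scope, and the float compute type is an assumption):

    // Assumes cudnn.h; error checking elided for brevity.
    cudnnConvolutionDescriptor_t conv;
    cudnnCreateConvolutionDescriptor(&conv);
    cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW,
                                    CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);
    cudnnSetConvolutionGroupCount(conv, iC);  // groupCount == iC => depthwise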
|
||||
|
@ -58,7 +58,7 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context,
|
|||
// input descriptor
|
||||
cudnnTensorDescriptor_t x;
|
||||
cudnnCreateTensorDescriptor(&x);
|
||||
if(input->ews() == 1)
|
||||
if(input->ews() == 1 && input->ordering() == 'c')
|
||||
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
|
||||
else
|
||||
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
|
||||
|
@ -73,7 +73,7 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context,
|
|||
// output descriptor
|
||||
cudnnTensorDescriptor_t z;
|
||||
cudnnCreateTensorDescriptor(&z);
|
||||
if(output->ews() == 1)
|
||||
if(output->ews() == 1 && output->ordering() == 'c')
|
||||
err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW);
|
||||
else
|
||||
err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1));
|
||||
|
@ -117,7 +117,8 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context,
|
|||
|
||||
cudnnTensorDescriptor_t b;
|
||||
cudnnCreateTensorDescriptor(&b);
|
||||
err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf());
|
||||
// err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf());
|
||||
err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1);
|
||||
if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err);
|
||||
err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer());
|
||||
if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err);
|
||||
|
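The bias change above drops the layout-dependent descriptor in favour of a fixed NCHW [1, oC, 1, 1] one: cudnnAddTensor broadcasts any unit dimension, so a per-channel bias described this way is added across batch and spatial dims regardless of how the output tensor itself is laid out. A sketch of the idea (float data type assumed for brevity):

#include <cudnn.h>

// Describe a length-oC bias as [1, oC, 1, 1]; cudnnAddTensor then broadcasts
// it over N, H, W of the destination, whatever format that destination uses.
cudnnStatus_t describeBias(cudnnTensorDescriptor_t b, int oC) {
    return cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, oC, 1, 1);
}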
@ -146,14 +147,14 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
    // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC)

    // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc
    // weights, gradW [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute
    // weights, gradW [iC, mC, kH, kW]
    // gradB [oC], may be nullptr
    // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
    // oC = iC*mC

    int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weights->sizeAt(1);

    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());

@ -165,7 +166,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
    // input descriptor
    cudnnTensorDescriptor_t x;
    cudnnCreateTensorDescriptor(&x);
    if(input->ews() == 1)
    if(input->ews() == 1 && input->ordering() == 'c')
        err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
    else
        err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));

@ -174,7 +175,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
    // gradO descriptor
    cudnnTensorDescriptor_t dz;
    cudnnCreateTensorDescriptor(&dz);
    if(gradO->ews() == 1)
    if(gradO->ews() == 1 && gradO->ordering() == 'c')
        err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
    else
        err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));

@ -183,7 +184,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
    // gradI descriptor
    cudnnTensorDescriptor_t dx;
    cudnnCreateTensorDescriptor(&dx);
    if(gradI->ews() == 1)
    if(gradI->ews() == 1 && gradI->ordering() == 'c')
        err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW);
    else
        err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1));

@ -241,7 +242,8 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
    if(gradB != nullptr) {
        cudnnTensorDescriptor_t db;
        cudnnCreateTensorDescriptor(&db);
        err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf());
        // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf());
        err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1);
        if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err);

        err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer());
@ -272,7 +274,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC

    auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)

@ -290,22 +292,31 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) {
    int dW = INT_ARG(7); // dilations width
    int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
    int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

    int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weights->sizeAt(indWmC); // channels multiplier

    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC);
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support format {oC, iC/groupCount, kH, kW}
    newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC --> iC, mC, kH, kW)
    std::vector<int> wPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} in our case
    if(0 == wFormat)
        wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW
    else if(1 == wFormat)
        wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW
    else
        wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW

    NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext());
    newWeights->assign(weights->permute(wPermut));

    NDArray* newInput = input;
    NDArray* newGradI = nullptr;
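The same three-way branch recurs in every conv helper touched by this commit, so the rule is worth isolating: whatever user-facing layout wFormat selects, the cudnn path wants the weights as [iC, mC, kH, kW]. A free-standing restatement of the mapping (hypothetical helper, not part of the codebase):

#include <stdexcept>
#include <vector>

// Permutation taking depthwise weights from the layout selected by wFormat
// to the [iC, mC, kH, kW] layout used for cudnn (mC = 1, oC = iC, groupCount = iC).
std::vector<int> depthwiseWeightsPermut(int wFormat) {
    switch (wFormat) {
        case 0:  return {2,3,0,1};  // [kH, kW, iC, mC] -> [iC, mC, kH, kW]
        case 1:  return {1,0,2,3};  // [mC, iC, kH, kW] -> [iC, mC, kH, kW]
        case 2:  return {3,0,1,2};  // [mC, kH, kW, iC] -> [iC, mC, kH, kW]
        default: throw std::invalid_argument("wFormat must be 0, 1 or 2");
    }
}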
@ -326,12 +337,13 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) {
PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC

    const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL
    const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

    const int mC = weights->sizeAt(3);
    const int mC = weights->sizeAt(0 == wFormat ? 3 : 0);

    const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF;
    const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF;

@ -344,12 +356,12 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) {
PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
    auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always
    auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]

    REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
@ -366,10 +378,11 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {
    int dW = INT_ARG(7); // dilations width
    int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
    int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

    int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weights->sizeAt(indWmC); // channels multiplier

    int trueoH, trueoW; // correct output height, width

@ -378,17 +391,30 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {
    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if(bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    std::vector<int> wPermut, gradWPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW}
    if(0 == wFormat) {
        wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW
        gradWPermut = {2,3,0,1}; // iC, mC, kH, kW -> kH, kW, iC, mC
    }
    else if(1 == wFormat) {
        wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW
        gradWPermut = {1,0,2,3}; // iC, mC, kH, kW -> mC, iC, kH, kW
    }
    else {
        wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW
        gradWPermut = {1,2,3,0}; // iC, mC, kH, kW -> mC, kH, kW, iC
    }

    NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support format {oC, iC/groupCount, kH, kW}
    NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, gradW->dataType(), gradW->getContext());
    NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext());

    newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC --> iC, mC, kH, kW)
    newWeights->assign(weights->permute(wPermut));

    NDArray* newInput = input;
    NDArray* newGradI = gradI;
@ -397,7 +423,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {

    depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW);

    newGradW->permutei({2,3,0,1}); // [iC, mC, kH, kW] -> [kH, kW, iC, mC]
    newGradW->permutei(gradWPermut);
    gradW->assign(newGradW);

    if(newInput != input) {

@ -420,14 +446,15 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {
PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next

    const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL
    const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
    const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

    const int mC = weights->sizeAt(3);
    const int mC = weights->sizeAt(0 == wFormat ? 3 : 0);

    const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF;
    const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF;
@ -98,7 +98,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) {

    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
    std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});

@ -54,7 +54,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) {

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
    REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());

@ -106,7 +106,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) {

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
    std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
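Pooling ops have no weights array, so they just pass 0 for the new wFormat argument; the parameter only affects which weight axes the indW* outputs point at. The assignment presumably made inside the helper, inferred from the call sites and from checks like weights->sizeAt(0 == wFormat ? 3 : 0) elsewhere in this commit, not verified against ConvolutionUtils itself:

// Assumed wFormat -> weight-axis mapping for conv2d; illustrative only.
void weightIndexesConv2d(int wFormat, int& indWkH, int& indWiC, int& indWoC) {
    if (wFormat == 0)      { indWkH = 0; indWiC = 2; indWoC = 3; }  // [kH, kW, iC, oC]
    else if (wFormat == 1) { indWkH = 2; indWiC = 1; indWoC = 0; }  // [oC, iC, kH, kW]
    else                   { indWkH = 1; indWiC = 3; indWoC = 0; }  // [oC, kH, kW, iC]
}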
@ -60,7 +60,7 @@ PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {

    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    if (paddingMode)
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

@ -105,7 +105,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {

    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

@ -61,7 +61,7 @@ PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) {

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    if(paddingMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

@ -109,7 +109,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) {

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
@ -91,12 +91,12 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);

    mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
    mkldnnUtils::setBlockStrides(x, x_user_md);

    // z, output
    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);

    mkldnnUtils::setBlockStrides(z, xRank, z_user_md);
    mkldnnUtils::setBlockStrides(z, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@ -112,7 +112,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    // provide memory and check whether reorder is required

    // x
    mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);
    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());

@ -207,19 +207,19 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);

    mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
    mkldnnUtils::setBlockStrides(x, x_user_md);

    // dLdO
    dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);

    mkldnnUtils::setBlockStrides(dLdO, xRank, dLdO_user_md);
    mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md);

    // dLdI
    dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);

    mkldnnUtils::setBlockStrides(dLdI, xRank, dLdI_user_md);
    mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@ -239,10 +239,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    // provide memory and check whether reorder is required

    // x
    mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // dLdO
    mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, args, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
    mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);

    // mean
    auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
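The batchnorm hunks also pick up the refactored mkldnnUtils::loadDataToMklStream, which now receives the target args slot directly instead of the whole args map plus a key. Its presumed shape, inferred from the call sites only (the real helper takes an NDArray*; a raw buffer stands in here to keep the sketch self-contained):

#include <dnnl.hpp>

// Wrap the user buffer, reorder into the primitive's preferred layout when
// the descriptors differ, and store the result in the caller-supplied slot.
static void loadDataToMklStream(void* buffer, dnnl::engine& engine, dnnl::stream& stream,
                                const dnnl::memory::desc& user_md,
                                const dnnl::memory::desc& prim_md, dnnl::memory& arg) {
    auto user_mem = dnnl::memory(user_md, engine, buffer);
    const bool reorderNeeded = prim_md != user_mem.get_desc();
    arg = reorderNeeded ? dnnl::memory(prim_md, engine) : user_mem;
    if (reorderNeeded)
        dnnl::reorder(user_mem, arg).execute(stream, user_mem, arg);
}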
@ -38,13 +38,13 @@ namespace platforms {
static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
                         const NDArray *bias, NDArray *output,
                         const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
                         const int paddingMode, const int isNCHW) {
                         const int paddingMode, const int isNCHW, const int wFormat) {

    // weights [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW]
    // mkl support weights in [oC, iC, kH, kW] format only

    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d

@ -53,8 +53,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
    dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    dnnl::memory::dims dilation = { dH-1, dW-1};

    auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
    auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;

    dnnl::memory::dims xDims = {bS, iC, iH, iW};
    dnnl::memory::dims wDims = {oC, iC, kH, kW};

@ -66,17 +66,29 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,

    // input
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    mkldnnUtils::setBlockStrides(input, 4, x_user_md);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
    mkldnnUtils::setBlockStrides(input, x_user_md);

    // weights
    dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
    w_user_md.data.format_kind = dnnl_blocked; // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
    if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
        w_user_md.data.format_kind = dnnl_blocked; // overrides format
        uint i0, i1, i2, i3;
        if(0 == wFormat) {
            i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
        }
        else if(1 == wFormat) {
            i0 = 0; i1 = 1; i2 = 2; i3 = 3;
        }
        else {
            i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
        }
        w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
        w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
        w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
        w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
    }

    // bias
    dnnl::memory::desc b_mkl_md;
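The stride override above is the heart of the mkl-dnn changes: rather than physically permuting weights, the user memory descriptor keeps the oihw axis order but points each logical axis at the stride of the matching axis in the array as stored, and oneDNN's reorder machinery consumes that view directly. A toy, oneDNN-free illustration of reading a [kH, kW, iC, oC] buffer through [oC, iC, kH, kW] strides:

#include <cstdio>

int main() {
    const int kH = 2, kW = 2, iC = 3, oC = 4;
    float w[kH * kW * iC * oC];
    for (int v = 0; v < kH * kW * iC * oC; ++v) w[v] = (float)v;

    // c-order strides of the stored [kH, kW, iC, oC] array
    const int sKH = kW * iC * oC, sKW = iC * oC, sIC = oC, sOC = 1;

    // element (o, i, h, x) of the logical [oC, iC, kH, kW] view -- no copy made
    int o = 1, i = 2, h = 0, x = 1;
    std::printf("%f\n", w[o * sOC + i * sIC + h * sKH + x * sKW]);
    return 0;
}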
@ -85,9 +97,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,

    // output
    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);

    mkldnnUtils::setBlockStrides(output, 4, z_user_md);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
    mkldnnUtils::setBlockStrides(output, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@ -103,10 +114,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
    // provide memory buffers and check whether reorder is required

    // input
    mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // weights
    mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
    mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

    // bias
    if(bias != nullptr) {
@ -135,13 +146,13 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
                           NDArray *gradI, NDArray *gradW, NDArray *gradB,
                           const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
                           const int paddingMode, const int isNCHW) {
                           const int paddingMode, const int isNCHW, const int wFormat) {

    // weights/gradW [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW]
    // mkl support weights/gradW in [oC, iC, kH, kW] format only

    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d

@ -150,8 +161,8 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
    dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    dnnl::memory::dims dilation = { dH-1, dW-1};

    auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
    auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;

    dnnl::memory::dims xDims = {bS, iC, iH, iW};
    dnnl::memory::dims wDims = {oC, iC, kH, kW};

@ -163,36 +174,60 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N

    // input
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    mkldnnUtils::setBlockStrides(input, 4, x_user_md);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
    mkldnnUtils::setBlockStrides(input, x_user_md);

    // weights
    dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
    w_user_md.data.format_kind = dnnl_blocked; // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
    if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
        w_user_md.data.format_kind = dnnl_blocked; // overrides format
        uint i0, i1, i2, i3;
        if(0 == wFormat) {
            i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
        }
        else if(1 == wFormat) {
            i0 = 0; i1 = 1; i2 = 2; i3 = 3;
        }
        else {
            i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
        }
        w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
        w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
        w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
        w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
    }

    // gradO
    dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
    mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
    mkldnnUtils::setBlockStrides(gradO, gradO_user_md);

    // gradI
    dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
    mkldnnUtils::setBlockStrides(gradI, gradI_user_md);

    // gradW
    dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat);
    gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
    gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
    gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(2);
    gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
    gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
    dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
    if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
        gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
        uint i0, i1, i2, i3;
        if(0 == wFormat) {
            i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
        }
        else if(1 == wFormat) {
            i0 = 0; i1 = 1; i2 = 2; i3 = 3;
        }
        else {
            i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
        }
        gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
        gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
        gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
        gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
    }

    // gradB
    dnnl::memory::desc gradB_mkl_md;
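The guard repeated for weights and gradW reads backwards at first glance but is exact: the stride override can be skipped only when the plain oihw tag already describes the array, i.e. when it is dense, c-ordered and wFormat is 1 ([oC, iC, kH, kW]). As a one-liner:

// True when the oihw format tag alone cannot describe the user weights.
bool needsStrideOverride(long long ews, char ordering, int wFormat) {
    return ews != 1 || ordering != 'c' || wFormat != 1;
}

Note also that, unlike the cudnn path, no permute-back of gradW is needed here: oneDNN writes the weight gradient straight through the stride-described view.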
@ -221,10 +256,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
    // provide memory buffers and check whether reorder is required

    // input
    mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // weights
    mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
    mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

    // gradO
    auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());

@ -489,7 +524,7 @@ static void conv2dBpMKLDNN(sd::graph::Context &block,
PLATFORM_IMPL(conv2d, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

    auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -500,24 +535,25 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
    int pW = INT_ARG(5); // paddings width
    int dH = INT_ARG(6); // dilations height
    int dW = INT_ARG(7); // dilations width
    int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
    int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
    bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

    int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
    int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width

    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
    conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);

    return Status::OK();
}
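Each platform impl now delegates the expected-shape computation to ConvolutionUtils::expectWeightsShape instead of hard-coding {kH, kW, iC, oC}. A plausible body for the conv2d variant, inferred from the wFormat comments throughout this commit (Nd4jLong aliased locally; the real overloads live in ConvolutionUtils):

#include <vector>
using Nd4jLong = long long; // stands in for the library typedef

std::vector<Nd4jLong> expectWeightsShape(int wFormat, Nd4jLong kH, Nd4jLong kW,
                                         Nd4jLong iC, Nd4jLong oC) {
    if (wFormat == 0) return {kH, kW, iC, oC};
    if (wFormat == 1) return {oC, iC, kH, kW};
    return {oC, kH, kW, iC}; // wFormat == 2
}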
@ -536,12 +572,12 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) {
PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
    auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
    auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always
    auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
    auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]

    int kH = INT_ARG(0); // filter(kernel) height

@ -554,10 +590,11 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
    int dW = INT_ARG(7); // dilations width
    int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
    int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
    int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

    int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    int trueoH, trueoW; // true output height, width
    ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);

@ -566,13 +603,13 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if(bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
    conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);

    return Status::OK();
}
@ -40,13 +40,13 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
                         const int sD, const int sH, const int sW,
                         const int pD, const int pH, const int pW,
                         const int dD, const int dH, const int dW,
                         const int paddingMode, const int isNCDHW) {
                         const int paddingMode, const int isNCDHW, const int wFormat) {

    // weights [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW]
    // mkl support weights in [oC, iC, kD, kH, kW] format only

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d

@ -56,8 +56,8 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
    dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
    dnnl::memory::dims dilation = {dD-1, dH-1, dW-1};

    auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
    auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw;

    dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
    dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};

@ -69,18 +69,30 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,

    // input
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    mkldnnUtils::setBlockStrides(input, 5, x_user_md);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
    mkldnnUtils::setBlockStrides(input, x_user_md);

    // weights
    dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
    w_user_md.data.format_kind = dnnl_blocked; // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
    w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
    if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
        w_user_md.data.format_kind = dnnl_blocked; // overrides format
        uint i0, i1, i2, i3, i4;
        if(0 == wFormat) {
            i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
        }
        else if(1 == wFormat) {
            i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
        }
        else {
            i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
        }
        w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
        w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
        w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
        w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
        w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
    }

    // bias
    dnnl::memory::desc b_mkl_md;
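The 5d case follows the 2d one exactly, with kD joining the kernel axes. Spelled out as a table (illustrative helper; the values restate the i0..i4 assignments above):

#include <vector>

// Position of each [oC, iC, kD, kH, kW] axis within the three conv3d
// weight formats, i.e. the source axis read for destination axis n.
std::vector<int> conv3dWeightsAxes(int wFormat) {
    if (wFormat == 0) return {4, 3, 0, 1, 2};  // from [kD, kH, kW, iC, oC]
    if (wFormat == 1) return {0, 1, 2, 3, 4};  // already [oC, iC, kD, kH, kW]
    return {0, 4, 1, 2, 3};                    // from [oC, kD, kH, kW, iC]
}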
|
@ -89,8 +101,8 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
|
||||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
||||
mkldnnUtils::setBlockStrides(output, 5, z_user_md);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(output, z_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -106,10 +118,10 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// bias
|
||||
if(bias != nullptr) {
|
||||
|
@ -140,13 +152,13 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
                           const int sD, const int sH, const int sW,
                           const int pD, const int pH, const int pW,
                           const int dD, const int dH, const int dW,
                           const int paddingMode, const int isNCDHW) {
                           const int paddingMode, const int isNCDHW, const int wFormat) {

    // weights/gradW [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW]
    // mkl support weights/gradW in [oC, iC, kD, kH, kW] format only

    int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d

@ -156,8 +168,8 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
    dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
    dnnl::memory::dims dilation = {dD-1, dH-1, dW-1};

    auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
    auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw;

    dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
    dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};

@ -169,40 +181,64 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N

    // input
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    mkldnnUtils::setBlockStrides(input, 5, x_user_md);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
    mkldnnUtils::setBlockStrides(input, x_user_md);

    // weights
    dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
    w_user_md.data.format_kind = dnnl_blocked; // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
    w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
    if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
        w_user_md.data.format_kind = dnnl_blocked; // overrides format
        uint i0, i1, i2, i3, i4;
        if(0 == wFormat) {
            i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
        }
        else if(1 == wFormat) {
            i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
        }
        else {
            i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
        }
        w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
        w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
        w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
        w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
        w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
    }

    // gradO
    dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);

    mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md);
    mkldnnUtils::setBlockStrides(gradO, gradO_user_md);

    // gradI
    dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);

    mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md);
    mkldnnUtils::setBlockStrides(gradI, gradI_user_md);

    // gradW
    dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat);
    gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
    gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
    gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
    gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
    gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
    gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2);
    dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
    if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
        gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
        uint i0, i1, i2, i3, i4;
        if(0 == wFormat) {
            i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
        }
        else if(1 == wFormat) {
            i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
        }
        else {
            i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
        }
        gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
        gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
        gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
        gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
        gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);
    }

    // gradB
    dnnl::memory::desc gradB_mkl_md;
|
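To make the stride permutation above concrete: a standalone sketch (not part of the patch) applying the same i0..i4 index tables to map a weights array's c-order strides onto oneDNN's canonical [oC, iC, kD, kH, kW] axis order. The shape numbers in main are made up for illustration.

#include <array>
#include <cstdio>

// perm[wFormat][k] = which source axis supplies canonical axis k; tables
// mirror the i0..i4 assignments in the diff above (wFormat must be 0, 1 or 2).
std::array<long, 5> permuteWeightStrides(const std::array<long, 5>& strides, int wFormat) {
    static const int perm[3][5] = {
        {4, 3, 0, 1, 2},  // 0: [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
        {0, 1, 2, 3, 4},  // 1: already [oC, iC, kD, kH, kW]
        {0, 4, 1, 2, 3},  // 2: [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
    };
    std::array<long, 5> out;
    for (int k = 0; k < 5; ++k)
        out[k] = strides[perm[wFormat][k]];
    return out;
}

int main() {
    // c-order strides of a hypothetical [kD=2, kH=3, kW=3, iC=4, oC=8] array
    std::array<long, 5> s = {288, 96, 32, 8, 1};
    for (long v : permuteWeightStrides(s, 0)) std::printf("%ld ", v);  // 1 8 288 96 32
    std::printf("\n");
}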
@@ -231,10 +267,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N

 // provide memory buffers and check whether reorder is required

 // input
-mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
+mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

 // weights
-mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
+mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

 // gradO
 auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
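The loadDataToMklStream call sites now pass the primitive's expected descriptor plus a reference into the args map, instead of the whole map plus an argument key. The helper's body is outside this diff; a plausible reconstruction of what such a helper does (an assumption, not the library source) is:

#include <dnnl.hpp>

// Hypothetical reconstruction of the reworked helper (the real body is not in
// this diff): wrap the user buffer, and reorder into the primitive's preferred
// layout only when the two descriptors differ.
static void loadDataToMklStreamSketch(void* userBuffer, const dnnl::engine& engine,
                                      dnnl::stream& stream,
                                      const dnnl::memory::desc& user_md,
                                      const dnnl::memory::desc& prim_md,
                                      dnnl::memory& arg /* slot such as args[DNNL_ARG_SRC] */) {
    auto user_mem = dnnl::memory(user_md, engine, userBuffer);
    if (prim_md != user_mem.get_desc()) {
        arg = dnnl::memory(prim_md, engine);                          // primitive-preferred layout
        dnnl::reorder(user_mem, arg).execute(stream, user_mem, arg);  // copy + re-layout
    }
    else
        arg = user_mem;                                               // zero-copy path
}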
@@ -486,7 +522,7 @@ static void conv3dBpMKLDNN(sd::graph::Context &block,

 PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {

 auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
+auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
 auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
@@ -507,12 +543,13 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {

 int dW = INT_ARG(11); // dilations width
 int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID
 int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
+int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]

 int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

-std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
 REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
 if (bias)
 REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

@@ -520,7 +557,7 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {

 if (paddingMode) // SAME
 ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

-conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
+conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat);

 return Status::OK();
 }
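expectWeightsShape replaces the hard-coded {kD, kH, kW, iC, oC} vector; judging from the wFormat comments above it simply selects one of the three layouts. A sketch of the presumed behaviour (assumption, not the actual helper):

#include <vector>
using Nd4jLong = long long;

// Presumed behaviour of ConvolutionUtils::expectWeightsShape for conv3d,
// inferred from the layout comments in this diff:
std::vector<Nd4jLong> expectWeightsShape3d(int wFormat, int kD, int kH, int kW, int iC, int oC) {
    if (0 == wFormat) return {kD, kH, kW, iC, oC};  // legacy libnd4j layout
    if (1 == wFormat) return {oC, iC, kD, kH, kW};  // oneDNN-style kernel layout
    return {oC, kD, kH, kW, iC};                    // channels-last kernel layout
}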
@@ -538,12 +575,12 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {

 PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {

 auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
+auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
 auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

 auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
-auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC] always
+auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]

 REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());

@@ -564,10 +601,11 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {

 int dW = INT_ARG(11); // dilations width
 int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
 int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
+int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]

 int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

 if(paddingMode) // SAME
 ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
@@ -576,26 +614,26 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {

 ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

 std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2});
-std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
 REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
 REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
 if (bias)
 REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

-conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
+conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat);

 return Status::OK();
 }

 PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {
-auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
-auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
-auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
+auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
+auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
+auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

-auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
-auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
-auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
+auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
+auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
+auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]

 return block.isUseMKLDNN() &&
 sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});

@@ -34,19 +34,30 @@ namespace platforms {

 //////////////////////////////////////////////////////////////////////////
 static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
 const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
-const int paddingMode, const bool isNCHW) {
+const int paddingMode, const bool isNCHW, const int wFormat) {

-// weights [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation
+// mkl supports weights format [oC, iC, kH, kW] only

 int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
 int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
+ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

 dnnl::memory::dims strides = { sH, sW };
 dnnl::memory::dims padding = { pH, pW };
 dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
 dnnl::memory::dims dilation = { dH-1, dW-1 };

+uint i0, i1, i2, i3;
+if(0 == wFormat) {
+    i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
+}
+else if(1 == wFormat) {
+    i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
+}
+else {
+    i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
+}

 // input type
 dnnl::memory::data_type xType;
 if(input->dataType() == DataType::FLOAT32)
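The padding_r entries above come from the deconvolution size relation oH = (iH - 1)*sH + kH - pH - padR (dilation omitted here since it is passed to oneDNN separately), solved for padR. A tiny self-contained check:

#include <cassert>

// Right padding a deconvolution needs so the size relation
// oH = (iH - 1)*sH + kH - pH - padR holds (dH == 1 assumed).
int deconvRightPadding(int iH, int oH, int kH, int sH, int pH) {
    return (iH - 1) * sH - oH + kH - pH;
}

int main() {
    // a 4 -> 8 upsampling deconv with kH=2, sH=2, pH=0 needs no extra padding
    assert(deconvRightPadding(4, 8, 2, 2, 0) == 0);
}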
@@ -76,8 +87,8 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N

 else
 zType = dnnl::memory::data_type::s32;

-dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
-dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
+dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;

 dnnl::memory::dims xDims = {bS, iC, iH, iW};
 dnnl::memory::dims wDims = {oC, iC, kH, kW};

@@ -87,17 +98,17 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N

 // input
 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
-dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
-mkldnnUtils::setBlockStrides(input, 4, x_user_md);
+dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
+mkldnnUtils::setBlockStrides(input, x_user_md);

 // weights
 dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
-dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
+dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
 w_user_md.data.format_kind = dnnl_blocked; // overrides format
-w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
-w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
-w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
+w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
+w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
+w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);

 // bias
 dnnl::memory::desc b_mkl_md;

@@ -106,8 +117,8 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N

 // output
 dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
-dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
-mkldnnUtils::setBlockStrides(output, 4, z_user_md);
+dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
+mkldnnUtils::setBlockStrides(output, z_user_md);

 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -124,10 +135,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N

 // provide memory buffers and check whether reorder is required

 // input
-mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
+mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

 // weights
-mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
+mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

 // bias
 if(bias != nullptr) {
@@ -156,19 +167,30 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N

 //////////////////////////////////////////////////////////////////////////
 static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
 const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
-const int paddingMode, const bool isNCHW) {
+const int paddingMode, const bool isNCHW, const int wFormat) {

-// weights and gradW [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation
+// mkl supports weights/gradW in [oC, iC, kH, kW] format only

 int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
 int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
+ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

 dnnl::memory::dims strides = { sH, sW };
 dnnl::memory::dims padding = { pH, pW };
 dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
 dnnl::memory::dims dilation = { dH-1, dW-1 };

+uint i0, i1, i2, i3;
+if(0 == wFormat) {
+    i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
+}
+else if(1 == wFormat) {
+    i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
+}
+else {
+    i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
+}

 // input type
 dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
 // weights type

@@ -182,8 +204,8 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const

 // gradB type
 dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;

-dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
-dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
+dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;

 dnnl::memory::dims xDims = {bS, iC, iH, iW};
 dnnl::memory::dims wDims = {oC, iC, kH, kW};

@@ -193,36 +215,36 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const

 // input
 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
-dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
-mkldnnUtils::setBlockStrides(input, 4, x_user_md);
+dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
+mkldnnUtils::setBlockStrides(input, x_user_md);

 // weights
 dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
-dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
+dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
 w_user_md.data.format_kind = dnnl_blocked; // overrides format
-w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
-w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
-w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
+w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
+w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
+w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);

 // gradO
 dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
-dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
-mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
+dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
+mkldnnUtils::setBlockStrides(gradO, gradO_user_md);

 // gradI
 dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
-dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
-mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
+dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
+mkldnnUtils::setBlockStrides(gradI, gradI_user_md);

 // gradW
 dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
-dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
+dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
 gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
-gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
-gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
-gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
+gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
+gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
+gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
+gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);

 // gradB
 dnnl::memory::desc gradB_mkl_md;
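Every weights/gradW descriptor here uses the same idiom: build the desc with a placeholder tag, flip format_kind to dnnl_blocked, and write the array's own strides in, so an arbitrarily permuted tensor can be consumed without an up-front copy. A generic sketch of the idiom (names are illustrative, not patch code):

#include <dnnl.hpp>
#include <cstddef>
#include <vector>

// Describe an arbitrarily-strided tensor to oneDNN by overriding the
// blocked-format strides, exactly as the w_user_md/gradW_user_md code above does.
dnnl::memory::desc stridedDesc(const dnnl::memory::dims& dims,
                               dnnl::memory::data_type type,
                               const std::vector<dnnl_dim_t>& strides) {
    dnnl::memory::desc md(dims, type, dnnl::memory::format_tag::oihw);
    md.data.format_kind = dnnl_blocked;            // overrides the tag chosen above
    for (std::size_t i = 0; i < strides.size(); ++i)
        md.data.format_desc.blocking.strides[i] = strides[i];
    return md;
}

oneDNN also offers a memory::desc constructor that takes a strides vector directly, which would express the same thing without poking at the C struct.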
@@ -251,10 +273,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const

 // provide memory buffers and check whether reorder is required

 // input
-mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
+mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

 // weights
-mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
+mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

 // gradO
 auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());

@@ -311,7 +333,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const

 PLATFORM_IMPL(deconv2d, ENGINE_CPU) {

 auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
+auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

 auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

@@ -327,14 +349,15 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) {

 int pW = INT_ARG(5); // paddings width
 int dH = INT_ARG(6); // dilations height
 int dW = INT_ARG(7); // dilations width
-int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
+int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
 int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]

 int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
 int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
+ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

-std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
+std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
 REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
 if (bias)
 REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

@@ -344,7 +367,7 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) {

 ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
 }

-deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
+deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);

 return Status::OK();
 }
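For callers the only visible change is one extra optional integer argument. A hypothetical argument vector for a 2x2, stride-2, SAME, NCHW deconv2d with legacy deconv-style weights (values are only an illustration; indices follow the INT_ARG reads above):

#include <vector>

std::vector<int> iArgs = {
    2, 2,        // kH, kW
    2, 2,        // sH, sW
    0, 0,        // pH, pW
    1, 1,        // dH, dW (1 = no dilation)
    1,           // paddingMode: 0-VALID, 1-SAME
    0,           // INT_ARG(9): 0-NCHW, 1-NHWC
    0            // wFormat: 0 - [kH, kW, oC, iC]; omitting it keeps old graphs valid
};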
@@ -377,12 +400,12 @@ PLATFORM_CHECK(deconv2d, ENGINE_CPU) {

 PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {

 auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
-auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
+auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
 auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next

 auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
-auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
+auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]

 REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());

@@ -398,18 +421,19 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {

 int pW = INT_ARG(5); // paddings width
 int dH = INT_ARG(6); // dilations height
 int dW = INT_ARG(7); // dilations width
-int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
+int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
 int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
+int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]

 int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
 int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
+ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

 int trueoH, trueoW; // true output height, width
 ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);

 std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
-std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
+std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
 REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
 REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
 if(bias)

@@ -420,19 +444,19 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {

 ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
 }

-deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
+deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);

 return Status::OK();
 }

 PLATFORM_CHECK(deconv2d_bp, ENGINE_CPU) {
 auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
-auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
+auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
 auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next

 auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
-auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
+auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]

 int dH = INT_ARG(6); // dilations height

@@ -34,7 +34,7 @@ namespace platforms {

 static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
 const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
 const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
-const bool isNCHW) {
+const bool isNCHW, const int wFormat) {

 // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
 // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]

@@ -52,8 +52,8 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad

 // gradI type
 dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;

-dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
-dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
+dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;

 dnnl::memory::dims xDims = {bS, iC, iH, iW};
 dnnl::memory::dims wDims = {oC, iC, kH, kW};

@@ -66,7 +66,7 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad

 // weights
 dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
-dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
+dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
 w_user_md.data.format_kind = dnnl_blocked; // overrides format
 w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
 w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);

@@ -75,13 +75,13 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad

 // gradO
 dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
-dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
-mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
+dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
+mkldnnUtils::setBlockStrides(gradO, gradO_user_md);

 // gradI
 dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
-dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
-mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
+dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
+mkldnnUtils::setBlockStrides(gradI, gradI_user_md);

 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -101,10 +101,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad

 // provide memory buffers and check whether reorder is required

 // weights
-mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
+mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

 // gradO
-mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
+mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);

 // gradI
 auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());

@@ -128,10 +128,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad

 PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {

 auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
-auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
+auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)

-auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

 int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
 int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width

@@ -143,6 +143,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {

 int dW = INT_ARG(7); // dilations width
 int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
 int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
+int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]

 const int rank = gradO->rankOf();

@@ -188,7 +189,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {

 // gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
 // }

-deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW);
+deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat);

 // delete weights;

@@ -35,19 +35,30 @@ namespace platforms {

 static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
 const int kD, const int kH, const int kW, const int sD, const int sH, const int sW,
 const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
-const bool isNCDHW) {
+const bool isNCDHW, const int wFormat) {

-// weights [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation
+// mkl supports weights in [oC, iC, kD, kH, kW] only

 int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
+ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

 dnnl::memory::dims strides = { sD, sH, sW };
 dnnl::memory::dims padding = { pD, pH, pW };
 dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
 dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };

+uint i0, i1, i2, i3, i4;
+if(0 == wFormat) {
+    i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
+}
+else if(1 == wFormat) {
+    i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
+}
+else {
+    i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
+}

 // input type
 dnnl::memory::data_type xType;
 if(input->dataType() == DataType::FLOAT32)
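One convention worth flagging in the lines above: oneDNN counts dilation as the number of gaps inserted between kernel taps, while libnd4j's dD/dH/dW use 1 to mean "no dilation", hence the dD-1/dH-1/dW-1 translation. The effective kernel extent is the same either way:

// Effective kernel extent under libnd4j's convention (d >= 1, d == 1 means none);
// oneDNN receives d-1 and computes the same extent internally.
inline int effectiveKernel(int k, int d) {
    return (k - 1) * d + 1;   // e.g. k=3, d=2 -> 5
}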
@@ -77,8 +88,8 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N

 else
 zType = dnnl::memory::data_type::s32;

-dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
-dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
+dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw;

 dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
 dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};

@@ -88,18 +99,18 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N

 // input
 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
-dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
-mkldnnUtils::setBlockStrides(input, 5, x_user_md);
+dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
+mkldnnUtils::setBlockStrides(input, x_user_md);

 // weights
 dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
-dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
+dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
 w_user_md.data.format_kind = dnnl_blocked; // overrides format
-w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
-w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4);
-w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
-w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
-w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
+w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
+w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
+w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
+w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
+w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);

 // bias
 dnnl::memory::desc b_mkl_md;

@@ -108,8 +119,8 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N

 // output
 dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
-dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
-mkldnnUtils::setBlockStrides(output, 5, z_user_md);
+dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
+mkldnnUtils::setBlockStrides(output, z_user_md);

 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -126,10 +137,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N

 // provide memory buffers and check whether reorder is required

 // input
-mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
+mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

 // weights
-mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
+mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

 // bias
 if(bias != nullptr) {
@@ -161,19 +172,30 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,

 const int sD, const int sH, const int sW,
 const int pD, const int pH, const int pW,
 const int dD, const int dH, const int dW,
-const bool isNCDHW) {
+const bool isNCDHW, const int wFormat) {

-// weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation
+// mkl supports weights/gradW in [oC, iC, kD, kH, kW] format only

 int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
+ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

 dnnl::memory::dims strides = { sD, sH, sW };
 dnnl::memory::dims padding = { pD, pH, pW };
 dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
 dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };

+uint i0, i1, i2, i3, i4;
+if(0 == wFormat) {
+    i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
+}
+else if(1 == wFormat) {
+    i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
+}
+else {
+    i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
+}

 // input type
 dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
 // weights type

@@ -187,8 +209,8 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,

 // gradB type
 dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;

-dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
-dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
+dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw;

 dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
 dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};

@@ -198,38 +220,38 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,

 // input
 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
-dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
-mkldnnUtils::setBlockStrides(input, 5, x_user_md);
+dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
+mkldnnUtils::setBlockStrides(input, x_user_md);

 // weights
 dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
-dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
+dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
 w_user_md.data.format_kind = dnnl_blocked; // overrides format
-w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
-w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4);
-w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
-w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
-w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
+w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
+w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
+w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
+w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
+w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);

 // gradO
 dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
-dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
-mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md);
+dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
+mkldnnUtils::setBlockStrides(gradO, gradO_user_md);

 // gradI
 dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
-dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
-mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md);
+dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
+mkldnnUtils::setBlockStrides(gradI, gradI_user_md);

 // gradW
-dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat);
-dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
+dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
+dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
 gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
-gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
-gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(4);
-gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
-gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
-gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2);
+gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
+gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
+gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
+gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
+gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);

 // gradB
 dnnl::memory::desc gradB_mkl_md;
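Note that the gradW_mkl_md change above is a genuine fix, not just a rename: it previously pinned the mkl-side descriptor to a concrete tag, whereas format_tag::any lets the backward-weights primitive choose its preferred layout, with a reorder bridging back to the user's layout afterwards. A sketch of the `any` idiom (illustrative helper, not patch code):

#include <dnnl.hpp>

// Generic `any` idiom: if the primitive's chosen layout differs from the
// user's, stage through a reorder; otherwise use the user memory directly.
dnnl::memory prepareArg(dnnl::memory user_mem, const dnnl::memory::desc& chosen_md,
                        const dnnl::engine& engine, dnnl::stream& stream) {
    if (user_mem.get_desc() == chosen_md)
        return user_mem;                       // layouts already agree
    dnnl::memory staged(chosen_md, engine);    // primitive-chosen layout
    dnnl::reorder(user_mem, staged).execute(stream, user_mem, staged);
    return staged;
}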
@@ -259,10 +281,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,

 // provide memory buffers and check whether reorder is required

 // input
-mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
+mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

 // weights
-mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
+mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

 // gradO
 auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());

@@ -319,7 +341,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,

 PLATFORM_IMPL(deconv3d, ENGINE_CPU) {

 auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
+auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

 auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

@@ -341,12 +363,13 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) {

 int dW = INT_ARG(11); // dilations width
 int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
 int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
+int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]

 int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
+ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

-std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
+std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
 REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
 if (bias)
 REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

@@ -356,7 +379,7 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) {

 ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
 }

-deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);
+deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat);

 return Status::OK();
 }
@ -390,12 +413,12 @@ PLATFORM_CHECK(deconv3d, ENGINE_CPU) {
|
|||
PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
|
||||
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
|
||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
|
||||
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
|
||||
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
|
||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||
|
@ -416,17 +439,18 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
|
|||
int dH = INT_ARG(10); // dilations height
|
||||
int dW = INT_ARG(11); // dilations width
|
||||
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||
|
||||
int trueoD, trueoH, trueoW; // true output height, width
|
||||
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
||||
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
|
||||
std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||
if(bias)
|
||||
|
@ -435,7 +459,7 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
|
|||
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);
|
||||
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat);
|
||||
|
||||
return Status::OK();
|
||||
}
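
The hunks above route the new wFormat flag into ConvolutionUtils::expectWeightsShape, whose body is not part of this diff. A minimal sketch of what such a mapping can look like for the deconv3d encodings documented above; the name, signature, and placement are assumptions, not the actual ConvolutionUtils API:

// Hypothetical sketch, not the actual ConvolutionUtils implementation:
// maps the wFormat flag to the weights shape expected by deconv3d.
#include <cstdint>
#include <cstdio>
#include <vector>

using Nd4jLong = int64_t;

static std::vector<Nd4jLong> expectWeightsShapeSketch(const int wFormat, const Nd4jLong kD, const Nd4jLong kH,
                                                      const Nd4jLong kW, const Nd4jLong oC, const Nd4jLong iC) {
    if (0 == wFormat)
        return {kD, kH, kW, oC, iC};   // legacy layout, kernel dims first
    if (1 == wFormat)
        return {iC, oC, kD, kH, kW};   // channels first
    return {iC, kD, kH, kW, oC};       // wFormat == 2
}

int main() {
    const auto s = expectWeightsShapeSketch(1, 2, 3, 3, 16, 8);
    std::printf("[%lld, %lld, %lld, %lld, %lld]\n", (long long)s[0], (long long)s[1],
                (long long)s[2], (long long)s[3], (long long)s[4]);   // [8, 16, 2, 3, 3]
    return 0;
}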

@@ -443,12 +467,12 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {

PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) {
    auto input   = INPUT_VARIABLE(0);                                               // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-   auto weights = INPUT_VARIABLE(1);                                               // [kD, kH, kW, oC, iC] always
+   auto weights = INPUT_VARIABLE(1);                                               // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
    auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                 // [oC]
    auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);       // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
-   auto gradW = OUTPUT_VARIABLE(1);                                                // [kD, kH, kW, oC, iC] always
+   auto gradW = OUTPUT_VARIABLE(1);                                                // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                  // [oC]

    int dD = INT_ARG(9);                                                            // dilations depth

@@ -35,19 +35,19 @@ namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
                                  const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
-                                 const int paddingMode, const bool isNCHW) {
+                                 const int paddingMode, const bool isNCHW, const int wFormat) {

    // mkl supports only following case: mC = 1, oC = iC

    // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
-   // weights [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute
+   // weights {iC, mC, 1, kH, kW}
    // bias [oC], may be nullptr
    // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
    // oC = iC*mC

    int bS, iC, iH, iW, mC, oC, oH, oW;                      // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;    // corresponding indexes
-   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
+   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weights->sizeAt(indWmC);                            // channels multiplier

    const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;    // dH == 1 for causal mode in conv1d
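
With wFormat in play, getSizesAndIndexesConv2d can no longer hard-code where iC, mC, and kH live in the weights array. A standalone sketch of the index bookkeeping for the depthwise layouts named above; the helper name is hypothetical, not the real ConvolutionUtils code:

// Hypothetical sketch of how the weight-dimension indexes shift with wFormat
// for the depthwise case; the real logic lives in getSizesAndIndexesConv2d.
#include <cstdio>

static void weightIndexesSketch(const int wFormat, int& indWiC, int& indWmC, int& indWkH) {
    if (0 == wFormat) {        // [kH, kW, iC, mC]
        indWkH = 0; indWiC = 2; indWmC = 3;
    } else if (1 == wFormat) { // [mC, iC, kH, kW]
        indWmC = 0; indWiC = 1; indWkH = 2;
    } else {                   // 2: [mC, kH, kW, iC]
        indWmC = 0; indWkH = 1; indWiC = 3;
    }
}

int main() {
    int indWiC, indWmC, indWkH;
    weightIndexesSketch(1, indWiC, indWmC, indWkH);
    std::printf("wFormat=1 -> indWiC=%d indWmC=%d indWkH=%d\n", indWiC, indWmC, indWkH); // 1 0 2
    return 0;
}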

@@ -57,6 +57,17 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
    dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    dnnl::memory::dims dilation  = { dH-1, dW-1};

+   uint i0, i1, i2, i3;
+   if(0 == wFormat) {
+       i0 = 2; i1 = 3; i2 = 0; i3 = 1;     // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]
+   }
+   else if(1 == wFormat) {
+       i0 = 1; i1 = 0; i2 = 2; i3 = 3;     // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW]
+   }
+   else {
+       i0 = 3; i1 = 0; i2 = 1; i3 = 2;     // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW]
+   }
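
The if/else above picks, for each supported user layout, which source dimension supplies each stride of the oneDNN goihw descriptor, so no physical transpose is ever performed. A standalone demo of the wFormat == 0 case under the assumption of c-order strides (plain C++, not library code):

// Standalone demo (hypothetical, not library code) of the stride permutation
// above for wFormat == 0: a c-order [kH, kW, iC, mC] weights array is exposed
// to oneDNN as goihw {iC, mC, 1, kH, kW} purely by reordering strides.
#include <cstdio>

int main() {
    const long kH = 3, kW = 3, iC = 8, mC = 2;
    // c-order strides of [kH, kW, iC, mC]
    long s[4] = { kW * iC * mC, iC * mC, mC, 1 };

    // i0..i3 for wFormat == 0, as in the hunk above
    const int i0 = 2, i1 = 3, i2 = 0, i3 = 1;
    long goihw[5] = { s[i0], s[i1], 0 /* dummy 'i' dim of size 1 */, s[i2], s[i3] };

    std::printf("goihw strides: g=%ld o=%ld i=%ld h=%ld w=%ld\n",
                goihw[0], goihw[1], goihw[2], goihw[3], goihw[4]);   // g=2 o=1 i=0 h=48 w=16
    return 0;
}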

    // input type
    dnnl::memory::data_type xType;
    if(input->dataType() == DataType::FLOAT32)

@@ -86,8 +97,8 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
    else
        zType = dnnl::memory::data_type::s32;

-   dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
-   dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;
+   dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+   dnnl::memory::format_tag wFormatMkl  = dnnl::memory::format_tag::goihw;

    dnnl::memory::dims xDims = {bS, iC, iH, iW};
    dnnl::memory::dims wDims = {iC, mC, 1, kH, kW};

@@ -97,18 +108,18 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
    // input
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
-   dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
-   mkldnnUtils::setBlockStrides(input, 4, x_user_md);
+   dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
+   mkldnnUtils::setBlockStrides(input, x_user_md);

-   // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
+   // weights
    dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
-   dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
+   dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-   w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2);  // permute
-   w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
+   w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute
+   w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
    w_user_md.data.format_desc.blocking.strides[2] = 0;
-   w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0);
-   w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1);
+   w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2);
+   w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3);

    // bias
    dnnl::memory::desc b_mkl_md;

@@ -117,8 +128,8 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
    // output
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
-   dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat);
-   mkldnnUtils::setBlockStrides(output, 4, z_user_md);
+   dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl);
+   mkldnnUtils::setBlockStrides(output, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -135,10 +146,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
    // provide memory buffers and check whether reorder is required

    // input
-   mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
+   mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // weights
-   mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
+   mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

    // bias
    if(bias != nullptr) {

@@ -166,19 +177,19 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
//////////////////////////////////////////////////////////////////////////
static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
                                          const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
-                                         const int paddingMode, const bool isNCHW) {
+                                         const int paddingMode, const bool isNCHW, const int wFormat) {

    // mkl supports only following case: mC = 1, oC = iC

    // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
-   // weights, gradW [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute
+   // weights/gradW {iC, mC, 1, kH, kW}
    // gradB [oC], may be nullptr
    // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
    // oC = iC*mC

    int bS, iC, iH, iW, mC, oC, oH, oW;                      // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;    // corresponding indexes
-   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
+   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weights->sizeAt(indWmC);

    const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;    // dH == 1 for causal mode in conv1d

@@ -188,6 +199,17 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
    dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    dnnl::memory::dims dilation  = { dH-1, dW-1};

+   uint i0, i1, i2, i3;
+   if(0 == wFormat) {
+       i0 = 2; i1 = 3; i2 = 0; i3 = 1;     // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]
+   }
+   else if(1 == wFormat) {
+       i0 = 1; i1 = 0; i2 = 2; i3 = 3;     // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW]
+   }
+   else {
+       i0 = 3; i1 = 0; i2 = 1; i3 = 2;     // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW]
+   }

    // input type
    dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
    // weights type

@@ -201,8 +223,8 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
    // gradB type
    dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;

-   dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
-   dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;
+   dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+   dnnl::memory::format_tag wFormatMkl  = dnnl::memory::format_tag::goihw;

    dnnl::memory::dims xDims = {bS, iC, iH, iW};
    dnnl::memory::dims wDims = {iC, mC, 1, kH, kW};

@@ -212,38 +234,38 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
    // input
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
-   dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
-   mkldnnUtils::setBlockStrides(input, 4, x_user_md);
+   dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
+   mkldnnUtils::setBlockStrides(input, x_user_md);

-   // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
+   // weights
    dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
-   dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
+   dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-   w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2);  // permute
-   w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
+   w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute
+   w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
    w_user_md.data.format_desc.blocking.strides[2] = 0;
-   w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0);
-   w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1);
+   w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2);
+   w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3);

    // gradO
    dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
-   dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat);
-   mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md);
+   dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl);
+   mkldnnUtils::setBlockStrides(gradO, gradO_user_md);

    // gradI
    dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
-   dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat);
-   mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md);
+   dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl);
+   mkldnnUtils::setBlockStrides(gradI, gradI_user_md);

-   // gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
+   // gradW
    dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
-   dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
+   dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
    gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
-   gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2);  // permute
-   gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
+   gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); // permute
+   gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
    gradW_user_md.data.format_desc.blocking.strides[2] = 0;
-   gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(0);
-   gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(1);
+   gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i2);
+   gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i3);

    // gradB
    dnnl::memory::desc gradB_mkl_md;

@@ -272,10 +294,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
    // provide memory buffers and check whether reorder is required

    // input
-   mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC);
+   mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // weights
-   mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
+   mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

    // gradO
    auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());

@@ -332,7 +354,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) {

    auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-   auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC] always
+   auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] = iC*mC

    auto output = OUTPUT_VARIABLE(0);                                    // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)

@@ -347,21 +369,22 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) {
    int dW = INT_ARG(7);                                                 // dilations width
    int paddingMode = INT_ARG(8);                                        // 0-VALID, 1-SAME
    int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;   // INT_ARG(9): 0-NCHW, 1-NHWC
+   int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;  // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

    int bS, iC, iH, iW, mC, oC, oH, oW;                      // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;    // corresponding indexes
-   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
+   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weights->sizeAt(indWmC);                            // channels multiplier

    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

-   std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
+   std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC);
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

-   depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
+   depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);

    return Status::OK();
}
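
For callers, the only visible change is the optional trailing integer argument. A runnable sketch of the depthwise_conv2d integer-argument layout implied by the INT_ARG indices above; the indices 0 through 6 (kH through dH) are inferred from the surrounding convention and should be treated as an assumption:

// Hypothetical sketch of the depthwise_conv2d iArgs layout after this change;
// index 10 (wFormat) is new and optional, so older call sites that pass only
// ten iArgs keep the legacy [kH, kW, iC, mC] weights layout.
#include <cstdio>
#include <vector>

int main() {
    //                      kH kW sH sW pH pW dH dW paddingMode isNHWC wFormat
    std::vector<int> iArgs{ 3, 3, 1, 1, 0, 0, 1, 1, 1,          0,     2 };

    const int wFormat = iArgs.size() > 10 ? iArgs[10] : 0;
    const char* layouts[] = { "[kH, kW, iC, mC]", "[mC, iC, kH, kW]", "[mC, kH, kW, iC]" };
    std::printf("weights layout: %s\n", layouts[wFormat]);   // [mC, kH, kW, iC]
    return 0;
}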

@@ -394,12 +417,12 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {

PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {

    auto input   = INPUT_VARIABLE(0);                                            // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-   auto weights = INPUT_VARIABLE(1);                                            // [kH, kW, iC, mC] always
+   auto weights = INPUT_VARIABLE(1);                                            // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;              // [oC] = [iC*mC]
    auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);    // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

    auto gradI = OUTPUT_NULLIFIED(0);                                            // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-   auto gradW = OUTPUT_NULLIFIED(1);                                            // [kH, kW, iC, mC] always
+   auto gradW = OUTPUT_NULLIFIED(1);                                            // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;              // [oC]

    REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());

@@ -416,10 +439,11 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
    int dW = INT_ARG(7);                                                 // dilations width
    int paddingMode = INT_ARG(8);                                        // 0-VALID, 1-SAME
    int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;   // INT_ARG(9): 1-NHWC, 0-NCHW
+   int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;  // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]

    int bS, iC, iH, iW, mC, oC, oH, oW;                      // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
    int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;    // corresponding indexes
-   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
+   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
    mC = weights->sizeAt(indWmC);                            // channels multiplier

    int trueoH, trueoW;    // correct output height, width

@@ -428,13 +452,13 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

    std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1});
-   std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC};
+   std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if(bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

-   depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
+   depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);

    return Status::OK();
}

@@ -443,12 +467,12 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {

PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) {

    auto input   = INPUT_VARIABLE(0);                                            // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-   auto weights = INPUT_VARIABLE(1);                                            // [kH, kW, iC, mC] always
+   auto weights = INPUT_VARIABLE(1);                                            // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;              // [oC] = [iC*mC]
    auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);    // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0);                                             // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-   auto gradW = OUTPUT_VARIABLE(1);                                             // [kH, kW, iC, mC] always
+   auto gradW = OUTPUT_VARIABLE(1);                                             // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;               // [oC]

    const DataType xType = input->dataType();

@@ -272,13 +272,13 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*

    // provide memory and check whether reorder is required
    // x
-   mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, lstm_prim_desc.src_layer_desc(), DNNL_ARG_SRC_LAYER);
+   mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]);

    // wx
-   mkldnnUtils::loadDataToMklStream(Wx, engine, stream, args, wx_user_md, lstm_prim_desc.weights_layer_desc(), DNNL_ARG_WEIGHTS_LAYER);
+   mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]);

    // wr
-   mkldnnUtils::loadDataToMklStream(Wr, engine, stream, args, wr_user_md, lstm_prim_desc.weights_iter_desc(), DNNL_ARG_WEIGHTS_ITER);
+   mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]);

    // h
    auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer());

@@ -288,17 +288,17 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*

    // b
    if(b) {
-       mkldnnUtils::loadDataToMklStream(b, engine, stream, args, b_user_md, lstm_prim_desc.bias_desc(), DNNL_ARG_BIAS);
+       mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]);
    }

    // hI
    if(hI) {
-       mkldnnUtils::loadDataToMklStream(hI, engine, stream, args, hI_user_md, lstm_prim_desc.src_iter_desc(), DNNL_ARG_SRC_ITER);
+       mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]);
    }

    // cI
    if(cI) {
-       mkldnnUtils::loadDataToMklStream(cI, engine, stream, args, cI_user_md, lstm_prim_desc.src_iter_c_desc(), DNNL_ARG_SRC_ITER_C);
+       mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]);
    }

    bool hLReorder(false), cLReorder(false);

@@ -163,7 +163,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
    // provide memory buffers and check whether reorder is required

    // input
-   mkldnnUtils::loadDataToMklStream(xTR, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
+   mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
    /*
    auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer());
    const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();

@@ -173,7 +173,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
    args[DNNL_ARG_SRC] = x_mkl_mem;
    */
    // y
-   mkldnnUtils::loadDataToMklStream(yTR, engine, stream, args, y_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS);
+   mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
    /*
    auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer());
    const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();

@@ -60,7 +60,7 @@ PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {

    int bS, iC, iH, iW, oC, oH, oW;                          // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;    // corresponding indexes
-   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    if (paddingMode)
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

@@ -102,7 +102,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {

    int bS, iC, iH, iW, oC, oH, oW;                          // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;    // corresponding indexes
-   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+   ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1});
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

@@ -60,7 +60,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;            // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;    // corresponding indexes
-   ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+   ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    if(paddingMode)    // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

@@ -107,7 +107,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;            // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;    // corresponding indexes
-   ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+   ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

@@ -56,25 +56,27 @@ dnnl::memory::format_tag getFormat(const int rank){
    }
    return dnnl::memory::format_tag::a;    // 1 == dataSetRank
}

//////////////////////////////////////////////////////////////////////
-void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd){
+void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd){

    if (array->ews() != 1 || array->ordering() != 'c') {
        mklMd.data.format_kind = dnnl_blocked;    // overrides format
-       for (auto i = 0; i < rank; ++i) {
+       for (auto i = 0; i < array->rankOf(); ++i) {
            mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i);
        }
    }
}
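
Dropping the explicit rank parameter removes a redundancy, since the NDArray already knows its own rank and a separately threaded value could silently disagree with it. A standalone analogue of the pattern under hypothetical stand-in types (not the library's NDArray or dnnl::memory::desc):

// Minimal standalone analogue (hypothetical types, not library code) of the
// setBlockStrides change: derive the rank from the array object itself.
#include <cstdio>
#include <vector>

struct FakeArray {                    // stand-in for NDArray
    std::vector<long> strides;
    int rankOf() const { return (int)strides.size(); }
    long strideAt(int i) const { return strides[i]; }
};

struct FakeDesc {                     // stand-in for dnnl::memory::desc
    long blockingStrides[12] = {};
};

// after the change: no 'rank' parameter, no way to pass a stale value
static void setBlockStridesSketch(const FakeArray& a, FakeDesc& md) {
    for (int i = 0; i < a.rankOf(); ++i)
        md.blockingStrides[i] = a.strideAt(i);
}

int main() {
    FakeArray a{{48, 16, 2, 1}};
    FakeDesc md;
    setBlockStridesSketch(a, md);
    std::printf("stride[0]=%ld rank=%d\n", md.blockingStrides[0], a.rankOf());   // 48, 4
    return 0;
}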

////////////////////////////////////////////////////////////////////////////////////////////////
-void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream,
-                         std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG ){
+void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
+                         dnnl::memory& arg) {

    auto user_mem = dnnl::memory(user_md, engine, array->getBuffer());
    const bool bReorder = primitive_md != user_mem.get_desc();
    auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem;
    if (bReorder)
        dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem);
-   args[DNNL_ARG] = mkl_mem;
+   arg = mkl_mem;
}
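
The new signature writes the prepared dnnl::memory straight into whatever slot the caller picks, so the unordered_map and the DNNL_ARG integer no longer travel through the helper; note that args[DNNL_ARG_SRC] at the call sites default-constructs the slot before the helper fills it. A standalone analogue of that design choice under hypothetical stand-in types:

// Standalone analogue (hypothetical types, not library code) of the signature
// change: instead of passing the map plus an integer key, hand in the slot.
#include <cstdio>
#include <unordered_map>

struct Mem { int payload = 0; };      // stand-in for dnnl::memory

// old style: helper needs to know about the container and the key
static void loadOld(std::unordered_map<int, Mem>& args, int key, int value) {
    args[key] = Mem{value};
}

// new style: helper touches only the one entry it is given
static void loadNew(Mem& arg, int value) {
    arg = Mem{value};
}

int main() {
    std::unordered_map<int, Mem> args;
    loadOld(args, 1, 10);
    loadNew(args[2], 20);    // operator[] creates the slot, as in args[DNNL_ARG_SRC]
    std::printf("%d %d\n", args[1].payload, args[2].payload);   // 10 20
    return 0;
}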

//////////////////////////////////////////////////////////////////////

@@ -95,7 +97,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,

    if(rank == 4) {    // 2d

-       ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+       ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

        strides = { sH, sW };
        kernel  = { kH, kW };

@@ -108,7 +110,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
    }
    else {             // 3d

-       ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
+       ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);

        strides = { sD, sH, sW };
        kernel  = { kD, kH, kW };

@@ -162,7 +164,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
    // provide memory buffers and check whether reorder is required

    // input
-   mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
+   mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // output
    auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());

@@ -199,7 +201,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,

    if(rank == 4) {    // 2d

-       ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+       ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

        strides = { sH, sW };
        kernel  = { kH, kW };

@@ -212,7 +214,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
    }
    else {             // 3d

-       ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
+       ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);

        strides = { sD, sH, sW };
        kernel  = { kD, kH, kW };

@@ -280,7 +282,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
    std::unordered_map<int, dnnl::memory> args;

    // gradO
-   mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
+   mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);

    // gradI
    auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());

@@ -291,7 +293,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
    if(mode == algorithm::pooling_max) {

        // input
-       mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);
+       mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

        // z
        auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);

@@ -131,7 +131,7 @@ namespace sd {
     * @param reference to memory descriptor
     * @return memory format
     */
-   void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd);
+   void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd);
    //////////////////////////////////////////////////////////////////////
    /**
     * This function loads and reorders user memory to mkl

@@ -143,8 +143,8 @@ namespace sd {
     * @param primitive memory descriptor
     * @param dnnl arg activation enumerator
     */
-   void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream,
-                            std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG);
+   void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
+                            dnnl::memory& arg);

    /**
     * Utility methods for MKLDNN
@@ -55,12 +55,12 @@ namespace sd {

    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xShape, xType, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
-   mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
+   mkldnnUtils::setBlockStrides(x, x_user_md);

    // z
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zShape, xType, format);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format);
-   mkldnnUtils::setBlockStrides(z, xRank, z_user_md);
+   mkldnnUtils::setBlockStrides(z, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -80,7 +80,7 @@ namespace sd {
    // provide memory buffers and check whether reorder is required

    // input
-   mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
+   mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());

@@ -156,19 +156,19 @@ namespace sd {
    // x
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-   mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
+   mkldnnUtils::setBlockStrides(x, x_user_md);

    // dLdx
    dnnl::memory::desc dLdx_mkl_md  = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
-   mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md);
+   mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
    // todo if mkl does not support broadcast we can remove this
    format = mkldnnUtils::getFormat(dLdzRank);

    // dLdz
    dnnl::memory::desc dLdz_mkl_md  = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
-   mkldnnUtils::setBlockStrides(dLdz, dLdzRank, dLdz_user_md);
+   mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -188,7 +188,7 @@ namespace sd {

    // provide memory buffers and check whether reorder is required for forward
    // input
-   mkldnnUtils::loadDataToMklStream(x, engine, stream, argsff, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);
+   mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]);

    // dLdx
    auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());

@@ -200,7 +200,7 @@ namespace sd {
    argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
    argsbp[DNNL_ARG_DST] = dLdx_mkl_mem;
    // dLdz
-   mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, argsbp, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
+   mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]);

    // run calculations forward
    dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff);

@@ -44,12 +44,12 @@ namespace sd {

    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-   mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
+   mkldnnUtils::setBlockStrides(x, x_user_md);

    // z
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
-   mkldnnUtils::setBlockStrides(z, xRank, z_user_md);
+   mkldnnUtils::setBlockStrides(z, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -68,7 +68,7 @@ namespace sd {

    // provide memory buffers and check whether reorder is required
    // input
-   mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
+   mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());

@@ -132,17 +132,17 @@ namespace sd {

    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-   mkldnnUtils::setBlockStrides(x, xRank, x_user_md);
+   mkldnnUtils::setBlockStrides(x, x_user_md);

    // dLdz
    dnnl::memory::desc dLdz_mkl_md  = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-   mkldnnUtils::setBlockStrides(dLdz, xRank, dLdz_user_md);
+   mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);

    // dLdx
    dnnl::memory::desc dLdx_mkl_md  = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-   mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md);
+   mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -162,10 +162,10 @@ namespace sd {

    // provide memory buffers and check whether reorder is required for forward
    // input
-   mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);
+   mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // dLdz
-   mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, args, dLdz_user_md, op_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);
+   mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);

    // dLdx
    auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long