Shyrma weights format (#329)

* - start to introduce additional weights formats into conv2d ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide weights format variety in backprop conv2d and deconv2d ops, testing and fixing bugs

Signed-off-by: Yurii <iuriish@yahoo.com>

* - forgot to recover kernels sizes in deconv2d_bp test

Signed-off-by: Yurii <iuriish@yahoo.com>

* - built in weights format in depthwise conv 2d op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide new weights formats in mkl dnn conv ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide new weights formats in cuda conv helpers

Signed-off-by: Yurii <iuriish@yahoo.com>

* - working with new weights format in cudnn conv api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - take into account order of arrays in cudnn tensor descriptions

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide new weights formats in cpu conv3d (ff/bp)

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide new weights formats in cpu deconv3d (ff/bp)

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide new weights formats in conv3d ops (ff/bp) based on mkl api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide new weights formats in conv3d ops (ff/bp) based on cudnn api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - resolve conflicts 2

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
Yurii Shyrma 2020-03-20 11:11:27 +02:00 committed by GitHub
parent 5dae4069cf
commit e700b59f80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 2185 additions and 1074 deletions

View File

@ -4076,7 +4076,7 @@ INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShap
// *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code // *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code
// FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities) // FIXME - indeed we don't need to allocate so large memory amount (4*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities)
Nd4jLong tempBuffer[4*MAX_RANK]; Nd4jLong tempBuffer[4*MAX_RANK];
Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides, *newStrides; Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides, *newStrides;

View File

@ -34,7 +34,7 @@ namespace ops {
CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW) auto output = OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW)
@ -45,12 +45,13 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
int dW = INT_ARG(3); // dilations width int dW = INT_ARG(3); // dilations width
int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL
int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC
int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]
const int rank = 3; const int rank = 3;
REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf());
REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf());
int indIOioC, indIiW, indWoC(2); int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0);
if(!isNCW) { if(!isNCW) {
indIOioC = 2; indIiW = 1; indIOioC = 2; indIiW = 1;
} }
@ -63,7 +64,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
int iC = input->sizeAt(indIOioC); // input channels int iC = input->sizeAt(indIOioC); // input channels
int oC = weights->sizeAt(indWoC); // output channels int oC = weights->sizeAt(indWoC); // output channels
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC}));
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -83,11 +84,11 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
sd::ops::conv2d conv2d; sd::ops::conv2d conv2d;
const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {}); const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {});
if (status != ND4J_STATUS_OK) if (status != ND4J_STATUS_OK)
return status; return status;
// ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW); // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -105,8 +106,9 @@ DECLARE_SHAPE_FN(conv1d) {
int dW = INT_ARG(3); // dilations width int dW = INT_ARG(3); // dilations width
int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME
int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]
int indIOioC, indIiW, indWoC(2); int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0);
if(!isNCW) { if(!isNCW) {
indIOioC = 2; indIiW = 1; indIOioC = 2; indIiW = 1;
} }
@ -123,7 +125,7 @@ DECLARE_SHAPE_FN(conv1d) {
int iC = inputShapeInfo[indIOioC+1]; // input channels int iC = inputShapeInfo[indIOioC+1]; // input channels
int oC = weightsShapeInfo[indWoC+1]; // output channels int oC = weightsShapeInfo[indWoC+1]; // output channels
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC}));
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo) if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -163,12 +165,12 @@ DECLARE_TYPES(conv1d) {
CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon auto gradI = OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon
auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC] always auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
@ -177,12 +179,14 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
int dW = INT_ARG(3); // dilations width int dW = INT_ARG(3); // dilations width
int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL
int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]
const int rank = 3; const int rank = 3;
REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf());
REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradO->rankOf()); REQUIRE_TRUE(gradO->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradO->rankOf());
int indIOioC, indIiW, indWoC(2);
int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0);
if(!isNCW) { if(!isNCW) {
indIOioC = 2; indIiW = 1; indIOioC = 2; indIiW = 1;
} }
@ -199,7 +203,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW});
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC}));
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
@ -222,11 +226,11 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC] auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
sd::ops::conv2d_bp conv2dBP; sd::ops::conv2d_bp conv2dBP;
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {}); auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {});
if (status != ND4J_STATUS_OK) if (status != ND4J_STATUS_OK)
return status; return status;
// ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW); // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -235,7 +239,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
DECLARE_SHAPE_FN(conv1d_bp) { DECLARE_SHAPE_FN(conv1d_bp) {
auto inputShapeInfo = inputShape->at(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) auto inputShapeInfo = inputShape->at(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
auto weightsShapeInfo = inputShape->at(1); // [kW, iC, oC] always auto weightsShapeInfo = inputShape->at(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next
@ -250,8 +254,9 @@ DECLARE_SHAPE_FN(conv1d_bp) {
int dW = INT_ARG(3); // dilations width int dW = INT_ARG(3); // dilations width
int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME
int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]
int indIOioC, indIiW, indWoC(2); int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0);
if(!isNCW) { if(!isNCW) {
indIOioC = 2; indIiW = 1; indIOioC = 2; indIiW = 1;
} }
@ -268,7 +273,7 @@ DECLARE_SHAPE_FN(conv1d_bp) {
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW});
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC}));
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo) if(biasShapeInfo)

View File

@ -37,7 +37,7 @@ namespace ops {
CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -49,21 +49,22 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
int dH = INT_ARG(6); // dilations height int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);
return Status::OK(); return Status::OK();
} }
@ -73,7 +74,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
DECLARE_SHAPE_FN(conv2d) { DECLARE_SHAPE_FN(conv2d) {
auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
//output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) //output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -86,6 +87,7 @@ DECLARE_SHAPE_FN(conv2d) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) height int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1)); // filter(kernel) width int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1)); // filter(kernel) width
@ -95,7 +97,7 @@ DECLARE_SHAPE_FN(conv2d) {
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
int indIOioC, indIiH, indWoC(3); int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0);
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indIOioC = 3; indIiH = 1;
} }
@ -109,7 +111,7 @@ DECLARE_SHAPE_FN(conv2d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels const int oC = weightsShapeInfo[indWoC+1]; // output channels
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo) if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -157,12 +159,12 @@ DECLARE_SHAPE_FN(conv2d) {
CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
int kH = INT_ARG(0); // filter(kernel) height int kH = INT_ARG(0); // filter(kernel) height
@ -175,6 +177,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
@ -182,19 +185,19 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
int trueoH, trueoW; // true output height, width int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong>expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong>expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong>expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong>expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);
return Status::OK(); return Status::OK();
} }
@ -204,7 +207,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
DECLARE_SHAPE_FN(conv2d_bp) { DECLARE_SHAPE_FN(conv2d_bp) {
auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
@ -224,8 +227,9 @@ DECLARE_SHAPE_FN(conv2d_bp) {
const int dW = INT_ARG(7); // dilations width const int dW = INT_ARG(7); // dilations width
const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
int indIOioC, indIiH, indOoH, indWoC(3); int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 3 : 0);
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indOoH = 1; indIOioC = 3; indIiH = 1; indOoH = 1;
} }
@ -243,7 +247,7 @@ DECLARE_SHAPE_FN(conv2d_bp) {
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo) if(biasShapeInfo)
@ -264,7 +268,7 @@ DECLARE_SHAPE_FN(conv2d_bp) {
CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
auto gradIShape = INPUT_VARIABLE(0); // [4] auto gradIShape = INPUT_VARIABLE(0); // [4]
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
@ -279,6 +283,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
const int rank = gradO->rankOf(); const int rank = gradO->rankOf();
@ -295,17 +300,17 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
int trueoH, trueoW; // true output height, width int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);
return Status::OK(); return Status::OK();
} }
@ -321,7 +326,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
DECLARE_SHAPE_FN(conv2d_input_bp) { DECLARE_SHAPE_FN(conv2d_input_bp) {
auto gradIShapeShapeInfo = inputShape->at(0); // [4] auto gradIShapeShapeInfo = inputShape->at(0); // [4]
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
const int rank = 4; const int rank = 4;
@ -340,8 +345,9 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
const int dW = INT_ARG(7); // dilations width const int dW = INT_ARG(7); // dilations width
const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
int indIOioC, indIiH, indWoC(3), indOoH; int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH;
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indOoH = 1; indIOioC = 3; indIiH = 1; indOoH = 1;
} }
@ -361,7 +367,7 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());

View File

@ -32,7 +32,7 @@ namespace ops {
CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
@ -52,14 +52,15 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
int dH = INT_ARG(10); // dilations height int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -71,14 +72,24 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
std::vector<int> permutForOutput; std::vector<int> permutForOutput;
if (isNCDHW) if (isNCDHW)
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
else else
input = new NDArray(input->permute({0,4,1,2,3})); input = new NDArray(input->permute({0,4,1,2,3}));
std::vector<int> wAxes;
if(0 == wFormat)
wAxes = {3,0,1,2};
else if(1 == wFormat)
wAxes = {1,2,3,4};
else
wAxes = {4,1,2,3};
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
// [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC] // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput); // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, iC, kD, kH, kW] = [bS, oD, oH, oW, oC]
// [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, kD, kH, kW, iC] = [bS, oD, oH, oW, oC]
MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, wAxes, permutForOutput);
if(bias) if(bias)
// output->applyBroadcast(broadcast::Add, {indIOioC}, bias); // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
@ -101,7 +112,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
DECLARE_SHAPE_FN(conv3dnew) { DECLARE_SHAPE_FN(conv3dnew) {
auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC] always auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
@ -118,13 +129,14 @@ DECLARE_SHAPE_FN(conv3dnew) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID; int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID;
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
const int rank = 5; const int rank = 5;
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo);
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);
int indIOioC, indIiD, indWoC(4); int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0);
if(!isNCDHW) { if(!isNCDHW) {
indIOioC = 4; indIiD = 1; indIOioC = 4; indIiD = 1;
} }
@ -139,7 +151,7 @@ DECLARE_SHAPE_FN(conv3dnew) {
int iC = inputShapeInfo[indIOioC+1]; // input channels int iC = inputShapeInfo[indIOioC+1]; // input channels
int oC = weightsShapeInfo[indWoC+1]; // output channels int oC = weightsShapeInfo[indWoC+1]; // output channels
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo) if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -174,12 +186,12 @@ DECLARE_SHAPE_FN(conv3dnew) {
CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
@ -200,17 +212,18 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
int trueoD, trueoH, trueoW; // true output depth/height/width int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
@ -231,10 +244,25 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW
} }
std::vector<int> wPermut, colPermut;
if(0 == wFormat) {
wPermut = {3,0,1,2,4};
colPermut = {2,3,4,1,0,5,6,7};
}
else if(1 == wFormat) {
wPermut = {1,2,3,4,0};
colPermut = {1,2,3,4,0,5,6,7};
}
else {
wPermut = {4,1,2,3,0};
colPermut = {2,3,4,1,0,5,6,7};
}
// ----- calculation of gradW and gradB ----- // // ----- calculation of gradW and gradB ----- //
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC] MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, wPermut); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
//----- calculation of gradO -----// //----- calculation of gradO -----//
if(gradB) { if(gradB) {
@ -246,7 +274,10 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
} }
//----- calculation of gradI -----// //----- calculation of gradI -----//
MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
// [oC, iC, kD, kH, kW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
// [oC, kD, kH, kW, iC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut);
ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
if(!isNCDHW) { if(!isNCDHW) {
@ -270,7 +301,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
DECLARE_SHAPE_FN(conv3dnew_bp) { DECLARE_SHAPE_FN(conv3dnew_bp) {
Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC] always Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
@ -288,6 +319,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
const int rank = 5; const int rank = 5;
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
@ -295,7 +327,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo); REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo);
int indIOioC, indIiD, indWoC(4); int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0);
if(!isNCDHW) { if(!isNCDHW) {
indIOioC = 4; indIiD = 1; indIOioC = 4; indIiD = 1;
} }
@ -314,7 +346,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo) if(biasShapeInfo)

View File

@ -35,7 +35,7 @@ namespace ops {
CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -53,12 +53,13 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -66,6 +67,12 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
if(!isNCHW) if(!isNCHW)
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
std::vector<int> colPermut;
if(1 == wFormat)
colPermut = {1, 2, 3, 0, 4, 5};
else
colPermut = {2, 3, 1, 0, 4, 5};
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
@ -73,8 +80,9 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
//----- calculation of output -----// //----- calculation of output -----//
// NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] // NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
// NCHW: [kH, kW, oC, iC] x [bS, iC, iH, iW] = [kH, kW, oC, bS, iH, iW] // NHWC: [iC, oC, kH, kW] x [bS, iH, iW, iC] = [oC, kH, kW, bS, iH, iW]
sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // NHWC: [iC, kH, kW, oC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut);
LaunchContext* ctx = block.launchContext(); LaunchContext* ctx = block.launchContext();
helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW] helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW]
@ -97,7 +105,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
DECLARE_SHAPE_FN(deconv2d) { DECLARE_SHAPE_FN(deconv2d) {
auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC] always auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
const int rank = 4; const int rank = 4;
@ -114,8 +122,9 @@ DECLARE_SHAPE_FN(deconv2d) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
int indIOioC, indIiH, indWoC(2); int indIOioC, indIiH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3));
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indIOioC = 3; indIiH = 1;
} }
@ -129,7 +138,7 @@ DECLARE_SHAPE_FN(deconv2d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels const int oC = weightsShapeInfo[indWoC+1]; // output channels
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo) if (biasShapeInfo)
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -163,12 +172,12 @@ DECLARE_SHAPE_FN(deconv2d) {
CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
@ -186,16 +195,17 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
int trueoH, trueoW; // true output height, width int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
@ -206,29 +216,34 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
} }
// ----- calculation of gradI -> pass it through conv2d_ff ----- //
// ----- calculation of gradI -> pass it through conv2d_ff ----- //
sd::ops::conv2d conv2d; sd::ops::conv2d conv2d;
const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, !isNCHW}, {}); const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, !isNCHW, wFormat}, {});
if (status != ND4J_STATUS_OK) if (status != ND4J_STATUS_OK)
return status; return status;
// -----prepare permutation arrays and axes for dot product ----- // // -----prepare permutation arrays and axes for dot product ----- //
std::vector<int> inputAxesForDot; std::vector<int> inputAxes;
if(!isNCHW) { if(!isNCHW) {
gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
inputAxesForDot = {0, 1, 2}; // bS, iH, iW inputAxes = {0, 1, 2}; // bS, iH, iW
} }
else else
inputAxesForDot = {0, 2, 3}; // bS, iH, iW inputAxes = {0, 2, 3}; // bS, iH, iW
std::vector<int> gradWAxes; // empty for wFormat = 1
if(0 == wFormat)
gradWAxes = {3, 2, 0, 1};
else if(2 == wFormat)
gradWAxes = {0, 3, 1, 2};
// ----- calculation of gradW ----- // // ----- calculation of gradW ----- //
NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext());
LaunchContext* ctx = block.launchContext(); LaunchContext* ctx = block.launchContext();
helpers::im2col(*ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, oC, oH, oW] is convoluted to [bS, oC, kH, kW, iH, iW] helpers::im2col(*ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, oC, oH, oW] is convoluted to [bS, oC, kH, kW, iH, iW]
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 4, 5}, {3, 2, 0, 1}); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW] MmulHelper::tensorDot(input, &columns, gradW, inputAxes, {0, 4, 5}, gradWAxes); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW]
// ----- calculation of gradB ----- // // ----- calculation of gradB ----- //
if(gradB) { if(gradB) {
@ -248,7 +263,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
DECLARE_SHAPE_FN(deconv2d_bp) { DECLARE_SHAPE_FN(deconv2d_bp) {
auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC] always auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
@ -267,8 +282,9 @@ DECLARE_SHAPE_FN(deconv2d_bp) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
int indIOioC, indIiH, indWoC(2), indOoH; int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3));
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indOoH = 1; indIOioC = 3; indIiH = 1; indOoH = 1;
} }
@ -286,7 +302,7 @@ DECLARE_SHAPE_FN(deconv2d_bp) {
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo) if(biasShapeInfo)

View File

@ -32,10 +32,10 @@ namespace ops {
CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI) auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
@ -47,6 +47,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
const int rank = gradO->rankOf(); const int rank = gradO->rankOf();
@ -57,20 +58,19 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
// create empty conv2d input array // create empty conv2d input array
NDArray input(gradO->ordering(), gradIShape->asVectorT<Nd4jLong>(), gradO->dataType(), block.launchContext()); NDArray input(gradO->ordering(), gradIShape->asVectorT<Nd4jLong>(), gradO->dataType(), block.launchContext());
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
int trueoH, trueoW; // true output height, width int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);
return Status::OK(); return Status::OK();
} }
@ -84,7 +84,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
DECLARE_SHAPE_FN(deconv2d_tf) { DECLARE_SHAPE_FN(deconv2d_tf) {
auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradIShapeShapeInfo = inputShape->at(0); // [4] auto gradIShapeShapeInfo = inputShape->at(0); // [4]
const int rank = 4; const int rank = 4;
@ -103,8 +103,9 @@ DECLARE_SHAPE_FN(deconv2d_tf) {
const int dW = INT_ARG(7); // dilations width const int dW = INT_ARG(7); // dilations width
const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
int indIOioC, indIiH, indWoC(3), indOoH; int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH;
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indOoH = 1; indIOioC = 3; indIiH = 1; indOoH = 1;
} }
@ -126,7 +127,7 @@ DECLARE_SHAPE_FN(deconv2d_tf) {
ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode); ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode);
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str()); REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str());
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());

View File

@ -32,7 +32,7 @@ namespace ops {
CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
@ -53,13 +53,14 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
int dH = INT_ARG(10); // dilations height int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -67,16 +68,23 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
if(!isNCDHW) if(!isNCDHW)
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
std::vector<int> colPermut;
if(1 == wFormat)
colPermut = {1,2,3,4,0,5,6,7};
else
colPermut = {2,3,4,1,0,5,6,7};
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
//----- calculation of output -----// //----- calculation of output -----//
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] // [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
// NCDHW: [kD, kH, kW, oC, iC] x [bS, iC, iD, iH, iW] = [kD, kH, kW, oC, bS, iD, iH, iW] // [iC, oC, kD, kH, kW] x [bS, iD, iH, iW, iC] = [oC, kD, kH, kW, bS, iD, iH, iW]
sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW] // [iC, kD, kH, kW, oC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW] sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
//----- add biases if required -----// //----- add biases if required -----//
if(bias) if(bias)
@ -101,7 +109,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
DECLARE_SHAPE_FN(deconv3d) { DECLARE_SHAPE_FN(deconv3d) {
auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NDCHW) auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NDCHW)
auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC] always auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
const int rank = 5; const int rank = 5;
@ -122,8 +130,9 @@ DECLARE_SHAPE_FN(deconv3d) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
int indIOioC, indIiD, indWoC(3); int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4));
if(!isNCDHW) { if(!isNCDHW) {
indIOioC = 4; indIiD = 1; indIOioC = 4; indIiD = 1;
} }
@ -138,7 +147,7 @@ DECLARE_SHAPE_FN(deconv3d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels const int oC = weightsShapeInfo[indWoC+1]; // output channels
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo) if (biasShapeInfo)
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo));
@ -174,12 +183,12 @@ DECLARE_SHAPE_FN(deconv3d) {
CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
@ -201,16 +210,17 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
int trueoD, trueoH, trueoW; // true output height, width int trueoD, trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
@ -221,7 +231,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradI -> pass it through conv3d_ff ----- // // ----- calculation of gradI -> pass it through conv3d_ff ----- //
sd::ops::conv3dnew conv3d; sd::ops::conv3dnew conv3d;
const Nd4jStatus status = conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, isSameMode, !isNCDHW}, {}); const Nd4jStatus status = conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, isSameMode, !isNCDHW, wFormat}, {});
if (status != ND4J_STATUS_OK) if (status != ND4J_STATUS_OK)
return status; return status;
@ -235,10 +245,16 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
else else
inputAxesForDot = {0, 2, 3, 4}; // bS, iD, iH, iW inputAxesForDot = {0, 2, 3, 4}; // bS, iD, iH, iW
std::vector<int> gradWAxes; // empty for wFormat = 1
if(0 == wFormat)
gradWAxes = {4,3,0,1,2};
else if(2 == wFormat)
gradWAxes = {0,4,1,2,3};
// ----- calculation of gradW ----- // // ----- calculation of gradW ----- //
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW] ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW] MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, gradWAxes); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]
// ----- calculation of gradB ----- // // ----- calculation of gradB ----- //
if(gradB) { if(gradB) {
@ -267,7 +283,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
DECLARE_SHAPE_FN(deconv3d_bp) { DECLARE_SHAPE_FN(deconv3d_bp) {
auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC] always auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC]
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
@ -290,8 +306,9 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
int indIOioC, indIiD, indWoC(3); int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4));
if(!isNCDHW) { if(!isNCDHW) {
indIOioC = 4; indIiD = 1; indIOioC = 4; indIiD = 1;
} }
@ -310,8 +327,8 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo) if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

View File

@ -32,7 +32,7 @@ namespace ops {
CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC
auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
@ -50,19 +50,20 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC);
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat);
return Status::OK(); return Status::OK();
} }
@ -75,7 +76,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
DECLARE_SHAPE_FN(depthwise_conv2d) { DECLARE_SHAPE_FN(depthwise_conv2d) {
Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, mC] always Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC
const int rank = 4; const int rank = 4;
@ -92,8 +93,9 @@ DECLARE_SHAPE_FN(depthwise_conv2d) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int indIOioC, indIiH, indWmC(3); int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0);
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indIOioC = 3; indIiH = 1;
} }
@ -109,7 +111,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d) {
const int oC = iC*mC; // output channels const int oC = iC*mC; // output channels
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo) if (biasShapeInfo)
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo));
@ -148,12 +150,12 @@ DECLARE_SHAPE_FN(depthwise_conv2d) {
CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
@ -170,23 +172,24 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
int trueoH, trueoW; // correct output height, width int trueoH, trueoW; // correct output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -214,8 +217,9 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int indIOioC, indIiH, indWmC(3); int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0);
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indIOioC = 3; indIiH = 1;
} }
@ -234,7 +238,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) {
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo) if(biasShapeInfo)

View File

@ -29,7 +29,7 @@ namespace ops {
CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC] always auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW) auto output = OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW)
@ -47,18 +47,19 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
int pW = 0; // paddings width int pW = 0; // paddings width
int dH = 1; // dilations height int dH = 1; // dilations height
int dW = 1; // dilations width int dW = 1; // dilations width
int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC]
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW); ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -73,7 +74,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
DECLARE_SHAPE_FN(pointwise_conv2d) { DECLARE_SHAPE_FN(pointwise_conv2d) {
Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
Nd4jLong* weightsShapeInfo = inputShape->at(1); // [1, 1, iC, oC] always Nd4jLong* weightsShapeInfo = inputShape->at(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC]
Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
const int rank = 4; const int rank = 4;
@ -81,8 +82,9 @@ DECLARE_SHAPE_FN(pointwise_conv2d) {
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC]
int indIOioC, indWoC(3); int indIOioC, indWoC(0 == wFormat ? 3 : 0);
if(!isNCHW) if(!isNCHW)
indIOioC = 3; indIOioC = 3;
else else
@ -92,7 +94,7 @@ DECLARE_SHAPE_FN(pointwise_conv2d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels const int oC = weightsShapeInfo[indWoC+1]; // output channels
std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo) if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

View File

@ -33,8 +33,8 @@ namespace ops {
CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC
NDArray *output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) NDArray *output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -66,17 +66,19 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
int dH = INT_ARG(6); // dilations height int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier mC = weightsDepth->sizeAt(indWmC); // channels multiplier
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
if(weightsPoint) { if(weightsPoint) {
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC}; std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC);
REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
} }
if (bias) if (bias)
@ -84,11 +86,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
if (iC == 1) { if (iC == 1) {
nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n",""); nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n","");
ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }
ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -103,8 +105,8 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
DECLARE_SHAPE_FN(sconv2d) { DECLARE_SHAPE_FN(sconv2d) {
auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weightsDShapeInfo = inputShape->at(1); // [kH, kW, iC, mC] always auto weightsDShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC] always Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
Nd4jLong* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr Nd4jLong* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr
if(block.width() == 3) if(block.width() == 3)
@ -135,8 +137,9 @@ DECLARE_SHAPE_FN(sconv2d) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int indIOioC, indIiH, indWmC(3); int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0);
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indIOioC = 3; indIiH = 1;
} }
@ -148,13 +151,13 @@ DECLARE_SHAPE_FN(sconv2d) {
const int iH = inputShapeInfo[indIiH+1]; // input height const int iH = inputShapeInfo[indIiH+1]; // input height
const int iW = inputShapeInfo[indIiH+2]; // input width const int iW = inputShapeInfo[indIiH+2]; // input width
const int iC = inputShapeInfo[indIOioC+1]; // input channels const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier
const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC) const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC)
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
if(weightsPShapeInfo) { if(weightsPShapeInfo) {
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC}; std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
} }
if (biasShapeInfo) if (biasShapeInfo)
@ -195,13 +198,13 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
NDArray *gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next NDArray *gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
NDArray *weightsDepth = INPUT_VARIABLE(2); // [kH, kW, iC, mC] always NDArray *weightsDepth = INPUT_VARIABLE(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr
NDArray *gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon NDArray *gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC] always NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
NDArray *gradB = nullptr; // [oC] NDArray *gradB = nullptr; // [oC]
if(block.width() == 4) { if(block.width() == 4) {
@ -244,17 +247,18 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier mC = weightsDepth->sizeAt(indWmC); // channels multiplier
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str()); REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str());
if(weightsPoint) { if(weightsPoint) {
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC}; std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC);
REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(gradWP).c_str()); REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(gradWP).c_str());
} }
@ -274,12 +278,12 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC}); auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC});
auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext()); auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext());
ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);
auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW, wFormat); // in this case oH=iH and oW=iW
gradO = gradIDepth; gradO = gradIDepth;
bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step
@ -288,7 +292,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
} }
// ----- apply depthwise_conv2d_bp ----- // // ----- apply depthwise_conv2d_bp ----- //
ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat);
if(weightsPoint) if(weightsPoint)
delete gradO; delete gradO;
@ -301,8 +305,8 @@ DECLARE_SHAPE_FN(sconv2d_bp) {
auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradOShapeInfo = inputShape->at(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradOShapeInfo = inputShape->at(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto weightsDShapeInfo = inputShape->at(2); // [kH, kW, iC, mC] always auto weightsDShapeInfo = inputShape->at(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC] always Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
Nd4jLong* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr Nd4jLong* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr
if(block.width() == 4) { if(block.width() == 4) {
@ -335,8 +339,9 @@ DECLARE_SHAPE_FN(sconv2d_bp) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int indIOioC, indIiH, indWmC(3); int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0);
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indIOioC = 3; indIiH = 1;
} }
@ -356,10 +361,10 @@ DECLARE_SHAPE_FN(sconv2d_bp) {
std::vector<Nd4jLong> expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1});
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
if(weightsPShapeInfo) { if(weightsPShapeInfo) {
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC}; std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC);
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
} }
if (biasShapeInfo) if (biasShapeInfo)

View File

@ -166,7 +166,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});

View File

@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
@ -172,7 +172,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});

View File

@ -168,7 +168,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});

View File

@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
@ -174,7 +174,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});

View File

@ -167,7 +167,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});

View File

@ -154,15 +154,24 @@ namespace sd {
} }
// evaluates sizes values and indexes using input and output arrays depending on data format // evaluates sizes values and indexes using input and output arrays depending on data format
static inline void getSizesAndIndexesConv2d(const bool isNCHW, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) {
getSizesAndIndexesConv2d(isNCHW, input.getShapeInfo(), output.getShapeInfo(), bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); getSizesAndIndexesConv2d(isNCHW, wFormat, input.getShapeInfo(), output.getShapeInfo(), bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
} }
static inline void getSizesAndIndexesConv2d(const bool isNCHW, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, oC] always // weights [kH, kW, iC, oC] (wFormat = 0), [oC, iC, kH, kW] (wFormat = 1), [oC, kH, kW, iC] (wFormat = 2)
// output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
indWkH = 0; indWiC = 2; indWoC = 3;
if(0 == wFormat) {
indWkH = 0; indWiC = 2; indWoC = 3;
}
else if(1 == wFormat) {
indWkH = 2; indWiC = 1; indWoC = 0;
}
else {
indWkH = 1; indWiC = 3; indWoC = 0;
}
if(!isNCHW) { if(!isNCHW) {
indIOioC = 3; indIiH = 1; indOoH = 1; indIOioC = 3; indIiH = 1; indOoH = 1;
@ -181,12 +190,21 @@ namespace sd {
} }
// evaluates sizes values and indexes using input and output arrays depending on data format // evaluates sizes values and indexes using input and output arrays depending on data format
static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) { static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) {
// input [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) // input [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
// weights [kD, kH, kW, iC, oC] (NDHWC) or [oC, iC, kD, kH, kW] (NCDHW) // weights [kD, kH, kW, iC, oC] (wFormat = 0), [oC, iC, kD, kH, kW] (wFormat = 1), [oC, kD, kH, kW, iC] (wFormat = 2)
// output [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) // output [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
indWkD = 0; indWiC = 3; indWoC = 4; if(0 == wFormat) {
indWkD = 0; indWiC = 3; indWoC = 4;
}
else if(1 == wFormat) {
indWkD = 2; indWiC = 1; indWoC = 0;
}
else {
indWkD = 1; indWiC = 4; indWoC = 0;
}
if(!isNCDHW) { if(!isNCDHW) {
indIOioC = 4; indIOioD = 1; indIOioC = 4; indIOioD = 1;
} }
@ -203,7 +221,6 @@ namespace sd {
oD = output.sizeAt(indIOioD); // output depth oD = output.sizeAt(indIOioD); // output depth
oH = output.sizeAt(indIOioD+1); // output height oH = output.sizeAt(indIOioD+1); // output height
oW = output.sizeAt(indIOioD+2); // output width oW = output.sizeAt(indIOioD+2); // output width
} }
// static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) { // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) {
@ -254,19 +271,41 @@ namespace sd {
// } // }
// } // }
static void conv2d(sd::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); static std::vector<Nd4jLong> expectWeightsShape(const int wFormat, const int kH, const int kW, const int iC, const int oC) {
if(0 == wFormat)
return std::vector<Nd4jLong>({kH, kW, iC, oC});
if(1 == wFormat)
return std::vector<Nd4jLong>({oC, iC, kH, kW});
return std::vector<Nd4jLong>({oC, kH, kW, iC});
}
static std::vector<Nd4jLong> expectWeightsShape(const int wFormat, const int kD, const int kH, const int kW, const int iC, const int oC) {
if(0 == wFormat)
return std::vector<Nd4jLong>({kD, kH, kW, iC, oC});
if(1 == wFormat)
return std::vector<Nd4jLong>({oC, iC, kD, kH, kW});
return std::vector<Nd4jLong>({oC, kD, kH, kW, iC});
}
static void conv2d(sd::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);
// static void conv2d(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs); // static void conv2d(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
// static void conv2dBP(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs); // static void conv2dBP(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);
static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);
static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);
static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat);
static void vol2col(sd::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); static void vol2col(sd::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);

View File

@ -258,10 +258,10 @@ namespace sd {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, oC] always // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// bias [oC] // bias [oC]
// output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -278,7 +278,7 @@ namespace sd {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
@ -291,6 +291,14 @@ namespace sd {
else else
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC
std::vector<int> wAxes;
if(0 == wFormat)
wAxes = {0, 1, 2};
else if(1 == wFormat)
wAxes = {2, 3, 1};
else
wAxes = {1, 2, 3};
NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext());
NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW}
NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext());
@ -298,7 +306,7 @@ namespace sd {
//----- calculation of output -----// //----- calculation of output -----//
auto ctx = block.launchContext(); auto ctx = block.launchContext();
helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]
//----- assign outTemp to output -----// //----- assign outTemp to output -----//
if(isNCHW) { if(isNCHW) {
@ -319,15 +327,15 @@ namespace sd {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, oC] always // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// bias [oC] // bias [oC]
// gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
// gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
// gradW [kH, kW, iC, oC] always // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// gradB [oC] // gradB [oC]
// kH filter(kernel) height // kH filter(kernel) height
@ -343,7 +351,7 @@ namespace sd {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
@ -359,13 +367,28 @@ namespace sd {
gradOaxesForDot = {0, 2, 3}; // bS, oH, oW gradOaxesForDot = {0, 2, 3}; // bS, oH, oW
} }
std::vector<int> wPermut, colPermut;
if(0 == wFormat) {
wPermut = {2, 0, 1, 3};
colPermut = {2, 3, 1, 0, 4, 5};
}
else if(1 == wFormat) {
wPermut = {1, 2, 3, 0};
colPermut = {1, 2, 3, 0, 4, 5};
}
else {
wPermut = {3, 1, 2, 0};
colPermut = {2, 3, 1, 0, 4, 5};
}
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
// ----- calculation of gradW ----- // // ----- calculation of gradW ----- //
if(gradW) { if(gradW) {
auto ctx = block.launchContext(); auto ctx = block.launchContext();
helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3}); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
} }
// ----- calculation of gradB ----- // // ----- calculation of gradB ----- //
@ -379,9 +402,12 @@ namespace sd {
} }
//----- calculation of gradI -----// //----- calculation of gradI -----//
sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
// [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW]
// [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut);
helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
if(!isNCHW) { if(!isNCHW) {
delete input; delete input;
@ -391,10 +417,10 @@ namespace sd {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, mC] always // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// bias [oC] = iC*mC // bias [oC] = iC*mC
// output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
@ -411,23 +437,30 @@ namespace sd {
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW]
std::vector<std::vector<Nd4jLong>> modifOutput; std::vector<std::vector<Nd4jLong>> modifOutput, modifWeights;
std::vector<Nd4jLong> outReShape; std::vector<Nd4jLong> outReShape;
if(!isNCHW) { if(!isNCHW) {
outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
} }
else { else {
outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
} }
if(0 == wFormat)
modifWeights = {{2,0,1,3},{iC,kH*kW,mC}};
else if(1 == wFormat)
modifWeights = {{1,2,3,0},{iC,kH*kW,mC}};
else
modifWeights = {{3,1,2,0},{iC,kH*kW,mC}};
if(paddingMode == 1) // SAME if(paddingMode == 1) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@ -435,7 +468,7 @@ namespace sd {
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);
helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
if(bias) if(bias)
// output->applyBroadcast(broadcast::Add, {indIOioC}, bias); // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
@ -447,14 +480,14 @@ namespace sd {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
// weights [kH, kW, iC, mC] always // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// bias [oC] = [iC*mC] // bias [oC] = [iC*mC]
// gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
// gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
// gradW [kH, kW, iC, mC] always // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// gradB [oC] // gradB [oC]
// kH filter(kernel) height // kH filter(kernel) height
@ -470,19 +503,19 @@ namespace sd {
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW]
std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2; std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2, modifWeights;
std::vector<Nd4jLong> gradOreShape; std::vector<Nd4jLong> gradOreShape;
if(!isNCHW) { if(!isNCHW) {
gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
} }
else { else {
gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
@ -490,6 +523,13 @@ namespace sd {
modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
} }
if(0 == wFormat)
modifWeights = {{2,0,1,3},{iC,kH*kW,mC}};
else if(1 == wFormat)
modifWeights = {{1,2,3,0},{iC,kH*kW,mC}};
else
modifWeights = {{3,1,2,0},{iC,kH*kW,mC}};
if(paddingMode == 1) // SAME if(paddingMode == 1) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@ -499,7 +539,7 @@ namespace sd {
// ----- calculation of gradW and gradB ----- // // ----- calculation of gradW and gradB ----- //
helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}}); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]
// ----- calculation of gradB ----- // // ----- calculation of gradB ----- //
if(gradB) { if(gradB) {
@ -513,8 +553,8 @@ namespace sd {
} }
//----- calculation of gradI -----// //----- calculation of gradI -----//
sd::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
if(!isNCHW) { if(!isNCHW) {
delete input; delete input;
@ -524,11 +564,11 @@ namespace sd {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weightsDepth [kH, kW, iC, mC] always // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// weightsPoint [1, 1, iC*mC, oC] always // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
// bias [oC], oC = iC*mC if weightsPoint=nullptr // bias [oC], oC = iC*mC if weightsPoint=nullptr
// output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -545,7 +585,7 @@ namespace sd {
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier mC = weightsDepth->sizeAt(indWmC); // channels multiplier
NDArray* outputDepth = output; NDArray* outputDepth = output;
@ -553,11 +593,11 @@ namespace sd {
outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext());
// ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW); ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat);
// ----- perform pointwise convolution (oH = iH, oW = iW) ----- // // ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
if (weightsPoint) { if (weightsPoint) {
ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW
delete outputDepth; delete outputDepth;
} }
} }
@ -1772,20 +1812,20 @@ namespace sd {
void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }
void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }
void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }
void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }
void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }
void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);

View File

@ -217,10 +217,10 @@ void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, ND
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, oC] always // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// bias [oC] // bias [oC]
// output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -237,7 +237,7 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
@ -248,6 +248,14 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr
else else
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC
std::vector<int> wAxes;
if(0 == wFormat)
wAxes = {0, 1, 2};
else if(1 == wFormat)
wAxes = {2, 3, 1};
else
wAxes = {1, 2, 3};
NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext());
NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW}
NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext());
@ -255,7 +263,7 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr
//----- calculation of output -----// //----- calculation of output -----//
auto ctx = block.launchContext(); auto ctx = block.launchContext();
helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]
//----- assign outTemp to output -----// //----- assign outTemp to output -----//
if(isNCHW) { if(isNCHW) {
@ -275,16 +283,16 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, mC] always // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// bias [oC] = iC*mC // bias [oC] = iC*mC
// output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
@ -301,23 +309,30 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW]
std::vector<std::vector<Nd4jLong>> modifOutput; std::vector<std::vector<Nd4jLong>> modifOutput, modifWeights;
std::vector<Nd4jLong> outReShape; std::vector<Nd4jLong> outReShape;
if(!isNCHW) { if(!isNCHW) {
outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW]
} }
else { else {
outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
} }
if(0 == wFormat)
modifWeights = {{2,0,1,3},{iC,kH*kW,mC}};
else if(1 == wFormat)
modifWeights = {{1,2,3,0},{iC,kH*kW,mC}};
else
modifWeights = {{3,1,2,0},{iC,kH*kW,mC}};
if(paddingMode == 1) // SAME if(paddingMode == 1) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@ -325,7 +340,7 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);
helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
if(bias) if(bias)
// output->applyBroadcast(broadcast::Add, {indIOioC}, bias); // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
@ -336,17 +351,17 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weightsDepth [kH, kW, iC, mC] always // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// weightsPoint [1, 1, iC*mC, oC] always // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
// bias [oC], oC = iC*mC if weightsPoint=nullptr // bias [oC], oC = iC*mC if weightsPoint=nullptr
// output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -363,7 +378,7 @@ static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDAr
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier mC = weightsDepth->sizeAt(indWmC); // channels multiplier
NDArray* outputDepth = output; NDArray* outputDepth = output;
@ -371,18 +386,18 @@ static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDAr
outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext());
// ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW); ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat);
// ----- perform pointwise convolution (oH = iH, oW = iW) ----- // // ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
if (weightsPoint) { if (weightsPoint) {
ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW
delete outputDepth; delete outputDepth;
} }
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -1176,15 +1191,15 @@ void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& inp
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, oC] always // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// bias [oC] // bias [oC]
// gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
// gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
// gradW [kH, kW, iC, oC] always // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
// gradB [oC] // gradB [oC]
// kH filter(kernel) height // kH filter(kernel) height
@ -1200,7 +1215,7 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
@ -1214,13 +1229,27 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA
gradOaxesForDot = {0, 2, 3}; // bS, oH, oW gradOaxesForDot = {0, 2, 3}; // bS, oH, oW
} }
std::vector<int> wPermut, colPermut;
if(0 == wFormat) {
wPermut = {2, 0, 1, 3};
colPermut = {2, 3, 1, 0, 4, 5};
}
else if(1 == wFormat) {
wPermut = {1, 2, 3, 0};
colPermut = {1, 2, 3, 0, 4, 5};
}
else {
wPermut = {3, 1, 2, 0};
colPermut = {2, 3, 1, 0, 4, 5};
}
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
// ----- calculation of gradW ----- // // ----- calculation of gradW ----- //
if(gradW) { if(gradW) {
auto ctx = block.launchContext(); auto ctx = block.launchContext();
helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3}); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
} }
// ----- calculation of gradB ----- // // ----- calculation of gradB ----- //
@ -1234,7 +1263,10 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA
} }
//----- calculation of gradI -----// //----- calculation of gradI -----//
sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
// [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW]
// [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
@ -1245,20 +1277,20 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
// input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
// weights [kH, kW, iC, mC] always // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// bias [oC] = [iC*mC] // bias [oC] = [iC*mC]
// gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
// gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
// gradW [kH, kW, iC, mC] always // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
// gradB [oC] // gradB [oC]
// kH filter(kernel) height // kH filter(kernel) height
@ -1274,11 +1306,11 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW]
std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2; std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2, modifWeights;
std::vector<Nd4jLong> gradOreShape; std::vector<Nd4jLong> gradOreShape;
if(!isNCHW) { if(!isNCHW) {
@ -1294,6 +1326,13 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
} }
if(0 == wFormat)
modifWeights = {{2,0,1,3},{iC,kH*kW,mC}};
else if(1 == wFormat)
modifWeights = {{1,2,3,0},{iC,kH*kW,mC}};
else
modifWeights = {{3,1,2,0},{iC,kH*kW,mC}};
if(paddingMode == 1) // SAME if(paddingMode == 1) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@ -1303,7 +1342,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
// ----- calculation of gradW and gradB ----- // // ----- calculation of gradW and gradB ----- //
helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}}); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]
// ----- calculation of gradB ----- // // ----- calculation of gradB ----- //
if(gradB) { if(gradB) {
@ -1316,7 +1355,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
} }
//----- calculation of gradI -----// //----- calculation of gradI -----//
sd::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
if(!isNCHW) { if(!isNCHW) {
@ -1326,8 +1365,8 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES);
} }

View File

@ -102,7 +102,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});

View File

@ -54,7 +54,7 @@ PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
@ -108,7 +108,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});

View File

@ -34,22 +34,25 @@ static void conv2dCUDNN(const LaunchContext* context,
const int sH, const int sW, const int sH, const int sW,
const int pH, const int pW, const int pH, const int pW,
const int dH, const int dW, const int dH, const int dW,
const int paddingMode, const bool isNCHW) { const int paddingMode, const bool isNCHW, const int wFormat) {
// cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC}
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err);
cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC);
// input descriptor // input descriptor
cudnnTensorDescriptor_t x; cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x); cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1) if(input->ews() == 1 && input->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
else else
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
@ -58,13 +61,13 @@ static void conv2dCUDNN(const LaunchContext* context,
// weights descriptor // weights descriptor
cudnnFilterDescriptor_t w; cudnnFilterDescriptor_t w;
cudnnCreateFilterDescriptor(&w); cudnnCreateFilterDescriptor(&w);
err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW); err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), formatW, oC, iC, kH, kW);
if(err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); if(err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err);
// output descriptor // output descriptor
cudnnTensorDescriptor_t z; cudnnTensorDescriptor_t z;
cudnnCreateTensorDescriptor(&z); cudnnCreateTensorDescriptor(&z);
if(output->ews() == 1) if(output->ews() == 1 && output->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW);
else else
err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1));
@ -104,10 +107,10 @@ static void conv2dCUDNN(const LaunchContext* context,
// add bias if it is present // add bias if it is present
if (bias != nullptr) { if (bias != nullptr) {
cudnnTensorDescriptor_t b; cudnnTensorDescriptor_t b;
cudnnCreateTensorDescriptor(&b); cudnnCreateTensorDescriptor(&b);
err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf());
err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1);
if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err);
err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer());
if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", err); if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", err);
@ -131,22 +134,23 @@ static void conv2dBpCUDNN(const LaunchContext* context,
const int sH, const int sW, const int sH, const int sW,
const int pH, const int pW, const int pH, const int pW,
const int dH, const int dW, const int dH, const int dW,
const int paddingMode, const bool isNCHW) { const int paddingMode, const bool isNCHW, const int wFormat) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", err); if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", err);
cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC);
// input descriptor // input descriptor
cudnnTensorDescriptor_t x; cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x); cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1) if(input->ews() == 1 && input->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
else else
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
@ -155,7 +159,7 @@ static void conv2dBpCUDNN(const LaunchContext* context,
// gradO descriptor // gradO descriptor
cudnnTensorDescriptor_t dz; cudnnTensorDescriptor_t dz;
cudnnCreateTensorDescriptor(&dz); cudnnCreateTensorDescriptor(&dz);
if(gradO->ews() == 1) if(gradO->ews() == 1 && gradO->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
else else
err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));
@ -164,7 +168,7 @@ static void conv2dBpCUDNN(const LaunchContext* context,
// gradI descriptor // gradI descriptor
cudnnTensorDescriptor_t dx; cudnnTensorDescriptor_t dx;
cudnnCreateTensorDescriptor(&dx); cudnnCreateTensorDescriptor(&dx);
if(gradI->ews() == 1) if(gradI->ews() == 1 && gradI->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW);
else else
err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1));
@ -173,7 +177,7 @@ static void conv2dBpCUDNN(const LaunchContext* context,
// gradW descriptor // gradW descriptor
cudnnFilterDescriptor_t dw; cudnnFilterDescriptor_t dw;
cudnnCreateFilterDescriptor(&dw); cudnnCreateFilterDescriptor(&dw);
err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW); err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, oC, iC, kH, kW);
if(err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); if(err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err);
// description of convolution // description of convolution
@ -220,7 +224,8 @@ static void conv2dBpCUDNN(const LaunchContext* context,
if(gradB != nullptr) { if(gradB != nullptr) {
cudnnTensorDescriptor_t db; cudnnTensorDescriptor_t db;
cudnnCreateTensorDescriptor(&db); cudnnCreateTensorDescriptor(&db);
err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf());
err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1);
if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err);
err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer());
@ -251,7 +256,7 @@ static void conv2dBpCUDNN(const LaunchContext* context,
PLATFORM_IMPL(conv2d, ENGINE_CUDA) { PLATFORM_IMPL(conv2d, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -263,7 +268,8 @@ PLATFORM_IMPL(conv2d, ENGINE_CUDA) {
int dH = INT_ARG(6); // dilations height int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
@ -273,31 +279,35 @@ PLATFORM_IMPL(conv2d, ENGINE_CUDA) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) { if (bias) {
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
REQUIRE_TRUE((bias->rankOf() == 1 && bias->strideAt(0) == 1) || (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && bias->strideAt(1) == 1) || (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && bias->strideAt(0) == 1), 0, "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !"); REQUIRE_TRUE((bias->rankOf() == 1 && bias->strideAt(0) == 1) || (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && bias->strideAt(1) == 1) || (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && bias->strideAt(0) == 1), 0, "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !");
} }
NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} NDArray* newWeights = weights; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC}
newWeights->assign(weights->permute({3,2,0,1})); // permute weights (kH, kW, iC, oC --> oC, iC, kH, kW) if(0 == wFormat) {
newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector<Nd4jLong>({oC, iC, kH, kW}) : std::vector<Nd4jLong>({oC, kH, kW, iC}), weights->dataType(), weights->getContext());
newWeights->assign(weights->permute(isNCHW ? std::vector<int>({3,2,0,1}) : std::vector<int>({3,0,1,2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC)
}
NDArray* newInput = input; NDArray* newInput = input;
NDArray* newGradI = nullptr; NDArray* newGradI = nullptr;
if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings
checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW);
conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW); conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW, wFormat);
if(newInput != input) if(newInput != input)
delete newInput; delete newInput;
delete newWeights; if(0 == wFormat)
delete newWeights;
return Status::OK(); return Status::OK();
} }
@ -322,12 +332,12 @@ PLATFORM_CHECK(conv2d, ENGINE_CUDA) {
PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
int kH = INT_ARG(0); // filter(kernel) height int kH = INT_ARG(0); // filter(kernel) height
@ -340,6 +350,7 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
@ -347,7 +358,7 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
int trueoH, trueoW; // true output height, width int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
@ -355,26 +366,30 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
NDArray* newGradW = new NDArray(gradW->ordering(), {oC, iC, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC}
NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext()); if(0 == wFormat) {
newGradW = new NDArray(gradW->ordering(), isNCHW ? std::vector<Nd4jLong>({oC, iC, kH, kW}) : std::vector<Nd4jLong>({oC, kH, kW, iC}), gradW->dataType(), gradW->getContext());
newWeights->assign(weights->permute({3,2,0,1})); // permute weights (kH, kW, iC, oC --> oC, iC, kH, kW) newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector<Nd4jLong>({oC, iC, kH, kW}) : std::vector<Nd4jLong>({oC, kH, kW, iC}), weights->dataType(), weights->getContext());
newWeights->assign(weights->permute(isNCHW ? std::vector<int>({3,2,0,1}) : std::vector<int>({3,0,1,2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC)
}
NDArray* newInput = input; NDArray* newInput = input;
NDArray* newGradI = gradI; NDArray* newGradI = gradI;
if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings
checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW);
conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW,wFormat);
newGradW->permutei({2,3,1,0}); // [oC, iC, kH, kW] -> [kH, kW, iC, oC] if(0 == wFormat) {
gradW->assign(newGradW); newGradW->permutei(isNCHW ? std::vector<int>({2,3,1,0}) : std::vector<int>({1,2,3,0})); // (oC, iC, kH, kW --> kH, kW, iC, oC) or (oC, kH, kW, iC --> kH, kW, iC, oC)
gradW->assign(newGradW);
}
if(newInput != input) { if(newInput != input) {
@ -387,8 +402,10 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) {
delete newGradI; delete newGradI;
} }
delete newWeights; if(0 == wFormat) {
delete newGradW; delete newWeights;
delete newGradW;
}
return Status::OK(); return Status::OK();
} }

View File

@ -34,13 +34,15 @@ static void conv3dCUDNN(const LaunchContext* context,
const int sD, const int sH, const int sW, const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW, const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW, const int dD, const int dH, const int dW,
const int paddingMode, const bool isNCDHW) { const int paddingMode, const bool isNCDHW, const int wFormat) {
// cudnn support only one format for weights {oC,iC,kD,kH,kW}
const int numDims = 5; const int numDims = 5;
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
@ -53,7 +55,7 @@ static void conv3dCUDNN(const LaunchContext* context,
const std::vector<int> xShape = {bS, iC, iD, iH, iW}; const std::vector<int> xShape = {bS, iC, iD, iH, iW};
const std::vector<int> zShape = {bS, oC, oD, oH, oW}; const std::vector<int> zShape = {bS, oC, oD, oH, oW};
const std::vector<int> wShape = {oC, iC, kD, kH, kW}; const std::vector<int> wShape = {oC, iC, kD, kH, kW};
const std::vector<int> bShape = {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; const std::vector<int> bShape = {1, oC, 1, 1, 1}; // {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)};
const std::vector<int> xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; const std::vector<int> xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
const std::vector<int> zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; const std::vector<int> zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)};
@ -120,7 +122,7 @@ static void conv3dCUDNN(const LaunchContext* context,
cudnnTensorDescriptor_t b; cudnnTensorDescriptor_t b;
cudnnCreateTensorDescriptor(&b); cudnnCreateTensorDescriptor(&b);
err = cudnnSetTensorNdDescriptorEx(b, format, cudnnDataType(bias->dataType()), numDims, bShape.data()); err = cudnnSetTensorNdDescriptorEx(b, /*format*/CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), numDims, bShape.data());
if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err); if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err);
err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer());
if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", err); if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", err);
@ -144,13 +146,15 @@ static void conv3dBpCUDNN(const LaunchContext* context,
const int sD, const int sH, const int sW, const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW, const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW, const int dD, const int dH, const int dW,
const int paddingMode, const bool isNCDHW) { const int paddingMode, const bool isNCDHW, const int wFormat) {
// cudnn supports only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} for weights/gradW
const int numDims = 5; const int numDims = 5;
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
@ -170,6 +174,7 @@ static void conv3dBpCUDNN(const LaunchContext* context,
const std::vector<int> dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; const std::vector<int> dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)};
cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC);
// input descriptor // input descriptor
cudnnTensorDescriptor_t x; cudnnTensorDescriptor_t x;
@ -201,7 +206,7 @@ static void conv3dBpCUDNN(const LaunchContext* context,
// gradW descriptor // gradW descriptor
cudnnFilterDescriptor_t dw; cudnnFilterDescriptor_t dw;
cudnnCreateFilterDescriptor(&dw); cudnnCreateFilterDescriptor(&dw);
err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, numDims, wShape.data()); err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, numDims, wShape.data());
if(err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err); if(err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err);
// description of convolution // description of convolution
@ -280,7 +285,7 @@ static void conv3dBpCUDNN(const LaunchContext* context,
PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
@ -301,34 +306,39 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode);
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} NDArray* newWeights = weights; // cudnn support only one format {oC,iC,kD,kH,kW}
newWeights->assign(weights->permute({4,3,0,1,2})); // permute weights (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) if(1 != wFormat) {
newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext());
newWeights->assign(weights->permute(0 == wFormat ? std::vector<int>({4,3,0,1,2}) : std::vector<int>({0,4,1,2,3}))); // kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW or oC, kD, kH, kW, iC --> oC, iC, kD, kH, kW
}
NDArray* newInput = input; NDArray* newInput = input;
NDArray* newGradI = nullptr; NDArray* newGradI = nullptr;
if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings
checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);
conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW); conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW, wFormat);
if(newInput != input) if(newInput != input)
delete newInput; delete newInput;
delete newWeights; if(1 != wFormat)
delete newWeights;
return Status::OK(); return Status::OK();
} }
@ -337,7 +347,7 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) {
PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) { PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID
@ -353,12 +363,12 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) {
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
@ -379,10 +389,11 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
int trueoD, trueoH, trueoW; // true output depth/height/width int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
@ -390,7 +401,7 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV3D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV3D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
@ -398,20 +409,25 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode);
NDArray* newGradW = new NDArray(gradW->ordering(), {oC, iC, kD, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC}
NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); if(0 == wFormat) {
newGradW = new NDArray(gradW->ordering(), isNCDHW ? std::vector<Nd4jLong>({oC, iC, kD, kH, kW}) : std::vector<Nd4jLong>({oC, kD, kH, kW, iC}), gradW->dataType(), gradW->getContext());
newWeights->assign(weights->permute({4,3,0,1,2})); // permute weights (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) newWeights = new NDArray(weights->ordering(), isNCDHW ? std::vector<Nd4jLong>({oC, iC, kD, kH, kW}) : std::vector<Nd4jLong>({oC, kD, kH, kW, iC}), weights->dataType(), weights->getContext());
newWeights->assign(weights->permute(isNCDHW ? std::vector<int>({4,3,0,1,2}) : std::vector<int>({4,0,1,2,3}))); // (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) or (kD, kH, kW, iC, oC --> oC, kD, kH, kW, iC)
}
NDArray* newInput = input; NDArray* newInput = input;
NDArray* newGradI = gradI; NDArray* newGradI = gradI;
if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings
checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);
conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW); conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW,wFormat);
if(0 == wFormat) {
newGradW->permutei(isNCDHW ? std::vector<int>({2,3,4,1,0}) : std::vector<int>({1,2,3,4,0})); // (oC, iC, kD, kH, kW --> kD, kH, kW, iC, oC) or (oC, kD, kH, kW, iC --> kD, kH, kW, iC, oC)
gradW->assign(newGradW);
}
newGradW->permutei({2,3,4,1,0}); // [oC, iC, kD, kH, kW] -> [kD, kH, kW, iC, oC]
gradW->assign(newGradW);
if(newInput != input) { if(newInput != input) {
@ -424,8 +440,10 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
delete newGradI; delete newGradI;
} }
delete newWeights; if(0 == wFormat) {
delete newGradW; delete newWeights;
delete newGradW;
}
return Status::OK(); return Status::OK();
} }
@ -433,7 +451,7 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) {
PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

View File

@ -124,7 +124,7 @@ void pooling2dCUDNN(const LaunchContext* context,
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
@ -135,7 +135,7 @@ void pooling2dCUDNN(const LaunchContext* context,
// input descriptor // input descriptor
cudnnTensorDescriptor_t x; cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x); cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1) if(input->ews() == 1 && input->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
else else
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
@ -144,7 +144,7 @@ void pooling2dCUDNN(const LaunchContext* context,
// output descriptor // output descriptor
cudnnTensorDescriptor_t z; cudnnTensorDescriptor_t z;
cudnnCreateTensorDescriptor(&z); cudnnCreateTensorDescriptor(&z);
if(output->ews() == 1) if(output->ews() == 1 && output->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW);
else else
err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1));
@ -187,7 +187,7 @@ void pooling2dBpCUDNN(const LaunchContext* context,
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
@ -198,7 +198,7 @@ void pooling2dBpCUDNN(const LaunchContext* context,
// input and gradI descriptor // input and gradI descriptor
cudnnTensorDescriptor_t x; cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x); cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1) if(input->ews() == 1 && input->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
else else
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
@ -207,7 +207,7 @@ void pooling2dBpCUDNN(const LaunchContext* context,
// gradO descriptor // gradO descriptor
cudnnTensorDescriptor_t dz; cudnnTensorDescriptor_t dz;
cudnnCreateTensorDescriptor(&dz); cudnnCreateTensorDescriptor(&dz);
if(gradO->ews() == 1) if(gradO->ews() == 1 && gradO->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
else else
err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));
@ -255,7 +255,7 @@ void pooling3dCUDNN(const LaunchContext* context,
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
const int pSizes[] = {pD, pH, pW}; const int pSizes[] = {pD, pH, pW};
const int sSizes[] = {sD, sH, sW}; const int sSizes[] = {sD, sH, sW};
@ -272,7 +272,7 @@ void pooling3dCUDNN(const LaunchContext* context,
// input descriptor // input descriptor
cudnnTensorDescriptor_t x; cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x); cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1) if(input->ews() == 1 && input->ordering() == 'c')
err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
else else
err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
@ -281,7 +281,7 @@ void pooling3dCUDNN(const LaunchContext* context,
// output descriptor // output descriptor
cudnnTensorDescriptor_t z; cudnnTensorDescriptor_t z;
cudnnCreateTensorDescriptor(&z); cudnnCreateTensorDescriptor(&z);
if(output->ews() == 1) if(output->ews() == 1 && output->ordering() == 'c')
err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape); err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape);
else else
err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides); err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides);
@ -330,7 +330,7 @@ void pooling3dBpCUDNN(const LaunchContext* context,
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
const int pSizes[] = {pD, pH, pW}; const int pSizes[] = {pD, pH, pW};
const int sSizes[] = {sD, sH, sW}; const int sSizes[] = {sD, sH, sW};
@ -347,7 +347,7 @@ void pooling3dBpCUDNN(const LaunchContext* context,
// input and gradI descriptor // input and gradI descriptor
cudnnTensorDescriptor_t x; cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x); cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1) if(input->ews() == 1 && input->ordering() == 'c')
err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
else else
err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
@ -356,7 +356,7 @@ void pooling3dBpCUDNN(const LaunchContext* context,
// gradO descriptor // gradO descriptor
cudnnTensorDescriptor_t dz; cudnnTensorDescriptor_t dz;
cudnnCreateTensorDescriptor(&dz); cudnnCreateTensorDescriptor(&dz);
if(gradO->ews() == 1) if(gradO->ews() == 1 && gradO->ordering() == 'c')
err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape); err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape);
else else
err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides); err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides);

View File

@ -39,14 +39,14 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context,
// cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC)
// input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc
// weights [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute // weights [iC, mC, kH, kW]
// bias [oC], may be nullptr // bias [oC], may be nullptr
// output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
// oC = iC*mC // oC = iC*mC
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(1); mC = weights->sizeAt(1);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
@ -58,7 +58,7 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context,
// input descriptor // input descriptor
cudnnTensorDescriptor_t x; cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x); cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1) if(input->ews() == 1 && input->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
else else
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
@ -73,7 +73,7 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context,
// output descriptor // output descriptor
cudnnTensorDescriptor_t z; cudnnTensorDescriptor_t z;
cudnnCreateTensorDescriptor(&z); cudnnCreateTensorDescriptor(&z);
if(output->ews() == 1) if(output->ews() == 1 && output->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW);
else else
err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1));
@ -117,7 +117,8 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context,
cudnnTensorDescriptor_t b; cudnnTensorDescriptor_t b;
cudnnCreateTensorDescriptor(&b); cudnnCreateTensorDescriptor(&b);
err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf());
err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1);
if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err);
err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer());
if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err); if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err);
@ -146,14 +147,14 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
// cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC)
// input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc
// weights, gradW [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute // weights, gradW [iC, mC, kH, kW]
// gradB [oC], may be nullptr // gradB [oC], may be nullptr
// gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
// oC = iC*mC // oC = iC*mC
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(1); mC = weights->sizeAt(1);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
@ -165,7 +166,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
// input descriptor // input descriptor
cudnnTensorDescriptor_t x; cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x); cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1) if(input->ews() == 1 && input->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
else else
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
@ -174,7 +175,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
// gradO descriptor // gradO descriptor
cudnnTensorDescriptor_t dz; cudnnTensorDescriptor_t dz;
cudnnCreateTensorDescriptor(&dz); cudnnCreateTensorDescriptor(&dz);
if(gradO->ews() == 1) if(gradO->ews() == 1 && gradO->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
else else
err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));
@ -183,7 +184,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
// gradI descriptor // gradI descriptor
cudnnTensorDescriptor_t dx; cudnnTensorDescriptor_t dx;
cudnnCreateTensorDescriptor(&dx); cudnnCreateTensorDescriptor(&dx);
if(gradI->ews() == 1) if(gradI->ews() == 1 && gradI->ordering() == 'c')
err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW);
else else
err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1));
@ -241,7 +242,8 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
if(gradB != nullptr) { if(gradB != nullptr) {
cudnnTensorDescriptor_t db; cudnnTensorDescriptor_t db;
cudnnCreateTensorDescriptor(&db); cudnnCreateTensorDescriptor(&db);
err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf());
err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1);
if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err);
err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer());
@ -272,7 +274,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context,
PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
@ -290,22 +292,31 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC);
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support format {oC, iC/groupCount, kH, kW} std::vector<int> wPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} in our case
newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC --> iC, mC, kH, kW) if(0 == wFormat)
wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW
else if(1 == wFormat)
wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW
else
wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW
NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext());
newWeights->assign(weights->permute(wPermut));
NDArray* newInput = input; NDArray* newInput = input;
NDArray* newGradI = nullptr; NDArray* newGradI = nullptr;
@ -326,12 +337,13 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) {
PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC
const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL
const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
const int mC = weights->sizeAt(3); const int mC = weights->sizeAt(0 == wFormat ? 3 : 0);
const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF;
const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF;
@ -344,12 +356,12 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) {
PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
@ -366,10 +378,11 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
int trueoH, trueoW; // correct output height, width int trueoH, trueoW; // correct output height, width
@ -378,17 +391,30 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
std::vector<int> wPermut, gradWPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW}
if(0 == wFormat) {
wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW
gradWPermut = {2,3,0,1}; // iC, mC, kH, kW -> kH, kW, iC, mC
}
else if(1 == wFormat) {
wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW
gradWPermut = {1,0,2,3}; // iC, mC, kH, kW -> mC, iC, kH, kW
}
else {
wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW
gradWPermut = {1,2,3,0}; // iC, mC, kH, kW -> mC, kH, kW, iC
}
NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support format {oC, iC/groupCount, kH, kW} NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, gradW->dataType(), gradW->getContext());
NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext());
newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC --> iC, mC, kH, kW) newWeights->assign(weights->permute(wPermut));
NDArray* newInput = input; NDArray* newInput = input;
NDArray* newGradI = gradI; NDArray* newGradI = gradI;
@ -397,7 +423,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {
depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW);
newGradW->permutei({2,3,0,1}); // [iC, mC, kH, kW] -> [kH, kW, iC, mC] newGradW->permutei(gradWPermut);
gradW->assign(newGradW); gradW->assign(newGradW);
if(newInput != input) { if(newInput != input) {
@ -420,14 +446,15 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) {
PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL
const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
const int mC = weights->sizeAt(3); const int mC = weights->sizeAt(0 == wFormat ? 3 : 0);
const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF;
const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF;

View File

@ -98,7 +98,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});

View File

@ -54,7 +54,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
@ -106,7 +106,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});

View File

@ -60,7 +60,7 @@ PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
if (paddingMode) if (paddingMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@ -105,7 +105,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

View File

@ -61,7 +61,7 @@ PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
if(paddingMode) // SAME if(paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
@ -109,7 +109,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

View File

@ -91,12 +91,12 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(x, xRank, x_user_md); mkldnnUtils::setBlockStrides(x, x_user_md);
// z, output // z, output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(z, xRank, z_user_md); mkldnnUtils::setBlockStrides(z, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -112,7 +112,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
// provide memory and check whether reorder is required // provide memory and check whether reorder is required
// x // x
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// z // z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
@ -207,19 +207,19 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(x, xRank, x_user_md); mkldnnUtils::setBlockStrides(x, x_user_md);
// dLdO // dLdO
dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(dLdO, xRank, dLdO_user_md); mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md);
// dLdI // dLdI
dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(dLdI, xRank, dLdI_user_md); mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -239,10 +239,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// provide memory and check whether reorder is required // provide memory and check whether reorder is required
// x // x
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_bp_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// dLdO // dLdO
mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, args, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
// mean // mean
auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());

View File

@ -38,13 +38,13 @@ namespace platforms {
static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
const NDArray *bias, NDArray *output, const NDArray *bias, NDArray *output,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int paddingMode, const int isNCHW) { const int paddingMode, const int isNCHW, const int wFormat) {
// weights [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW] // mkl support weights in [oC, iC, kH, kW] format only
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
@ -53,8 +53,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
dnnl::memory::dims dilation = { dH-1, dW-1}; dnnl::memory::dims dilation = { dH-1, dW-1};
auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kH, kW};
@ -66,17 +66,29 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, 4, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); uint i0, i1, i2, i3;
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); if(0 == wFormat) {
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3;
}
else {
i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
}
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
}
// bias // bias
dnnl::memory::desc b_mkl_md; dnnl::memory::desc b_mkl_md;
@ -85,9 +97,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(output, z_user_md);
mkldnnUtils::setBlockStrides(output, 4, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -103,10 +114,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
@ -135,13 +146,13 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
NDArray *gradI, NDArray *gradW, NDArray *gradB, NDArray *gradI, NDArray *gradW, NDArray *gradB,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int paddingMode, const int isNCHW) { const int paddingMode, const int isNCHW, const int wFormat) {
// weights/gradW [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW] // mkl support weights/gradW in [oC, iC, kH, kW] format only
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
@ -150,8 +161,8 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
dnnl::memory::dims dilation = { dH-1, dW-1}; dnnl::memory::dims dilation = { dH-1, dW-1};
auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kH, kW};
@ -163,36 +174,60 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, 4, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); uint i0, i1, i2, i3;
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); if(0 == wFormat) {
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3;
}
else {
i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
}
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
}
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
// gradW // gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(2); uint i0, i1, i2, i3;
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); if(0 == wFormat) {
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3;
}
else {
i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
}
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
}
// gradB // gradB
dnnl::memory::desc gradB_mkl_md; dnnl::memory::desc gradB_mkl_md;
@ -221,10 +256,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
@ -489,7 +524,7 @@ static void conv2dBpMKLDNN(sd::graph::Context &block,
PLATFORM_IMPL(conv2d, ENGINE_CPU) { PLATFORM_IMPL(conv2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -500,24 +535,25 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
int pW = INT_ARG(5); // paddings width int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -536,12 +572,12 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) {
PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
int kH = INT_ARG(0); // filter(kernel) height int kH = INT_ARG(0); // filter(kernel) height
@ -554,10 +590,11 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
int trueoH, trueoW; // true output height, width int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
@ -566,13 +603,13 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }

View File

@ -40,13 +40,13 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
const int sD, const int sH, const int sW, const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW, const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW, const int dD, const int dH, const int dW,
const int paddingMode, const int isNCDHW) { const int paddingMode, const int isNCDHW, const int wFormat) {
// weights [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW] // mkl support weights in [oC, iC, kD, kH, kW] format only
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
// const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
@ -56,8 +56,8 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; dnnl::memory::dims dilation = {dD-1, dH-1, dW-1};
auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw;
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
@ -69,18 +69,30 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, 5, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); uint i0, i1, i2, i3, i4;
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); if(0 == wFormat) {
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); }
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
}
else {
i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
}
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
}
// bias // bias
dnnl::memory::desc b_mkl_md; dnnl::memory::desc b_mkl_md;
@ -89,8 +101,8 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(output, 5, z_user_md); mkldnnUtils::setBlockStrides(output, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -106,10 +118,10 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
@ -140,13 +152,13 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
const int sD, const int sH, const int sW, const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW, const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW, const int dD, const int dH, const int dW,
const int paddingMode, const int isNCDHW) { const int paddingMode, const int isNCDHW, const int wFormat) {
// weights/gradW [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW] // mkl support weights/gradW in [oC, iC, kD, kH, kW] format only
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
// const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
@ -156,8 +168,8 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; dnnl::memory::dims dilation = {dD-1, dH-1, dW-1};
auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw;
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
@ -169,40 +181,64 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, 5, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); uint i0, i1, i2, i3, i4;
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); if(0 == wFormat) {
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); }
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
}
else {
i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
}
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
}
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md); mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md); mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
// gradW // gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); uint i0, i1, i2, i3, i4;
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); if(0 == wFormat) {
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2); }
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
}
else {
i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
}
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);
}
// gradB // gradB
dnnl::memory::desc gradB_mkl_md; dnnl::memory::desc gradB_mkl_md;
@ -231,10 +267,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
@ -486,7 +522,7 @@ static void conv3dBpMKLDNN(sd::graph::Context &block,
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
@ -507,12 +543,13 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -520,7 +557,7 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
if (paddingMode) // SAME if (paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -538,12 +575,12 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC] always auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
@ -564,10 +601,11 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
if(paddingMode) // SAME if(paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
@ -576,26 +614,26 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat);
return Status::OK(); return Status::OK();
} }
PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
return block.isUseMKLDNN() && return block.isUseMKLDNN() &&
sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB}); sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});

View File

@ -34,19 +34,30 @@ namespace platforms {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int paddingMode, const bool isNCHW) { const int paddingMode, const bool isNCHW, const int wFormat) {
// weights [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation // mkl supports weights format [oC, iC, kH, kW] only
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
dnnl::memory::dims strides = { sH, sW }; dnnl::memory::dims strides = { sH, sW };
dnnl::memory::dims padding = { pH, pW }; dnnl::memory::dims padding = { pH, pW };
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dH-1, dW-1 }; dnnl::memory::dims dilation = { dH-1, dW-1 };
uint i0, i1, i2, i3;
if(0 == wFormat) {
i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
}
else {
i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
}
// input type // input type
dnnl::memory::data_type xType; dnnl::memory::data_type xType;
if(input->dataType() == DataType::FLOAT32) if(input->dataType() == DataType::FLOAT32)
@ -76,8 +87,8 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
else else
zType = dnnl::memory::data_type::s32; zType = dnnl::memory::data_type::s32;
dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kH, kW};
@ -87,17 +98,17 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
mkldnnUtils::setBlockStrides(input, 4, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
// bias // bias
dnnl::memory::desc b_mkl_md; dnnl::memory::desc b_mkl_md;
@ -106,8 +117,8 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
mkldnnUtils::setBlockStrides(output, 4, z_user_md); mkldnnUtils::setBlockStrides(output, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -124,10 +135,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
@ -156,19 +167,30 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int paddingMode, const bool isNCHW) { const int paddingMode, const bool isNCHW, const int wFormat) {
// weights and gradW [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation // mkl supports weights/gradW in [oC, iC, kH, kW] format only
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
dnnl::memory::dims strides = { sH, sW }; dnnl::memory::dims strides = { sH, sW };
dnnl::memory::dims padding = { pH, pW }; dnnl::memory::dims padding = { pH, pW };
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dH-1, dW-1 }; dnnl::memory::dims dilation = { dH-1, dW-1 };
uint i0, i1, i2, i3;
if(0 == wFormat) {
i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
}
else {
i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
}
// input type // input type
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// weights type // weights type
@ -182,8 +204,8 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// gradB type // gradB type
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kH, kW};
@ -193,36 +215,36 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
mkldnnUtils::setBlockStrides(input, 4, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
// gradW // gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
// gradB // gradB
dnnl::memory::desc gradB_mkl_md; dnnl::memory::desc gradB_mkl_md;
@ -251,10 +273,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
@ -311,7 +333,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
PLATFORM_IMPL(deconv2d, ENGINE_CPU) { PLATFORM_IMPL(deconv2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
@ -327,14 +349,15 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) {
int pW = INT_ARG(5); // paddings width int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -344,7 +367,7 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) {
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
} }
deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -377,12 +400,12 @@ PLATFORM_CHECK(deconv2d, ENGINE_CPU) {
PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
@ -398,18 +421,19 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {
int pW = INT_ARG(5); // paddings width int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
int trueoH, trueoW; // true output height, width int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
@ -420,19 +444,19 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
} }
deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }
PLATFORM_CHECK(deconv2d_bp, ENGINE_CPU) { PLATFORM_CHECK(deconv2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
int dH = INT_ARG(6); // dilations height int dH = INT_ARG(6); // dilations height

View File

@ -34,7 +34,7 @@ namespace platforms {
static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI, static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW, const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const bool isNCHW) { const bool isNCHW, const int wFormat) {
// gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC] // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
@ -52,8 +52,8 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// gradI type // gradI type
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kH, kW};
@ -66,7 +66,7 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
@ -75,13 +75,13 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -101,10 +101,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO // gradO
mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
// gradI // gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
@ -128,10 +128,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI) auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
@ -143,6 +143,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
const int rank = gradO->rankOf(); const int rank = gradO->rankOf();
@ -188,7 +189,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
// gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] // gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
// } // }
deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat);
// delete weights; // delete weights;

View File

@ -35,19 +35,30 @@ namespace platforms {
static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
const bool isNCDHW) { const bool isNCDHW, const int wFormat) {
// weights [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation // mkl supports weights in [oC, iC, kD, kH, kW] only
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
dnnl::memory::dims strides = { sD, sH, sW }; dnnl::memory::dims strides = { sD, sH, sW };
dnnl::memory::dims padding = { pD, pH, pW }; dnnl::memory::dims padding = { pD, pH, pW };
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
uint i0, i1, i2, i3, i4;
if(0 == wFormat) {
i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
}
else {
i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
}
// input type // input type
dnnl::memory::data_type xType; dnnl::memory::data_type xType;
if(input->dataType() == DataType::FLOAT32) if(input->dataType() == DataType::FLOAT32)
@ -77,8 +88,8 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
else else
zType = dnnl::memory::data_type::s32; zType = dnnl::memory::data_type::s32;
dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw;
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
@ -88,18 +99,18 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
mkldnnUtils::setBlockStrides(input, 5, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4); w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
// bias // bias
dnnl::memory::desc b_mkl_md; dnnl::memory::desc b_mkl_md;
@ -108,8 +119,8 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
mkldnnUtils::setBlockStrides(output, 5, z_user_md); mkldnnUtils::setBlockStrides(output, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -126,10 +137,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
@ -161,19 +172,30 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
const int sD, const int sH, const int sW, const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW, const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW, const int dD, const int dH, const int dW,
const bool isNCDHW) { const bool isNCDHW, const int wFormat) {
// weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation // mkl supports weights/gradW in [oC, iC, kD, kH, kW] format only
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
dnnl::memory::dims strides = { sD, sH, sW }; dnnl::memory::dims strides = { sD, sH, sW };
dnnl::memory::dims padding = { pD, pH, pW }; dnnl::memory::dims padding = { pD, pH, pW };
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
uint i0, i1, i2, i3, i4;
if(0 == wFormat) {
i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
}
else {
i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
}
// input type // input type
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// weights type // weights type
@ -187,8 +209,8 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// gradB type // gradB type
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw;
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
@ -198,38 +220,38 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
mkldnnUtils::setBlockStrides(input, 5, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4); w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md); mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md); mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
// gradW // gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(4); gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2); gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);
// gradB // gradB
dnnl::memory::desc gradB_mkl_md; dnnl::memory::desc gradB_mkl_md;
@ -259,10 +281,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
@ -319,7 +341,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
PLATFORM_IMPL(deconv3d, ENGINE_CPU) { PLATFORM_IMPL(deconv3d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
@ -341,12 +363,13 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) {
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -356,7 +379,7 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) {
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
} }
deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -390,12 +413,12 @@ PLATFORM_CHECK(deconv3d, ENGINE_CPU) {
PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
@ -416,17 +439,18 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
int dH = INT_ARG(10); // dilations height int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
int trueoD, trueoH, trueoW; // true output height, width int trueoD, trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
@ -435,7 +459,7 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -443,12 +467,12 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) { PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
int dD = INT_ARG(9); // dilations depth int dD = INT_ARG(9); // dilations depth

View File

@ -35,19 +35,19 @@ namespace platforms {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int paddingMode, const bool isNCHW) { const int paddingMode, const bool isNCHW, const int wFormat) {
// mkl supports only following case: mC = 1, oC = iC // mkl supports only following case: mC = 1, oC = iC
// input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
// weights [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute // weights {iC, mC, 1, kH, kW}
// bias [oC], may be nullptr // bias [oC], may be nullptr
// output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
// oC = iC*mC // oC = iC*mC
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
@ -57,6 +57,17 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
dnnl::memory::dims dilation = { dH-1, dW-1}; dnnl::memory::dims dilation = { dH-1, dW-1};
uint i0, i1, i2, i3;
if(0 == wFormat) {
i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW]
}
else {
i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW]
}
// input type // input type
dnnl::memory::data_type xType; dnnl::memory::data_type xType;
if(input->dataType() == DataType::FLOAT32) if(input->dataType() == DataType::FLOAT32)
@ -86,8 +97,8 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
else else
zType = dnnl::memory::data_type::s32; zType = dnnl::memory::data_type::s32;
dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; dnnl::memory::dims wDims = {iC, mC, 1, kH, kW};
@ -97,18 +108,18 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, 4, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // permute w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = 0; w_user_md.data.format_desc.blocking.strides[2] = 0;
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0); w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1); w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3);
// bias // bias
dnnl::memory::desc b_mkl_md; dnnl::memory::desc b_mkl_md;
@ -117,8 +128,8 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// output // output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl);
mkldnnUtils::setBlockStrides(output, 4, z_user_md); mkldnnUtils::setBlockStrides(output, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -135,10 +146,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias // bias
if(bias != nullptr) { if(bias != nullptr) {
@ -166,19 +177,19 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int paddingMode, const bool isNCHW) { const int paddingMode, const bool isNCHW, const int wFormat) {
// mkl supports only following case: mC = 1, oC = iC // mkl supports only following case: mC = 1, oC = iC
// input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
// weights, gradW [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute // weights/gradW {iC, mC, 1, kH, kW}
// gradB [oC], may be nullptr // gradB [oC], may be nullptr
// gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
// oC = iC*mC // oC = iC*mC
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); mC = weights->sizeAt(indWmC);
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
@ -188,6 +199,17 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
dnnl::memory::dims dilation = { dH-1, dW-1}; dnnl::memory::dims dilation = { dH-1, dW-1};
uint i0, i1, i2, i3;
if(0 == wFormat) {
i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW]
}
else {
i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW]
}
// input type // input type
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// weights type // weights type
@ -201,8 +223,8 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// gradB type // gradB type
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims xDims = {bS, iC, iH, iW};
dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; dnnl::memory::dims wDims = {iC, mC, 1, kH, kW};
@ -212,38 +234,38 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// input // input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, 4, x_user_md); mkldnnUtils::setBlockStrides(input, x_user_md);
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; // weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // permute w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = 0; w_user_md.data.format_desc.blocking.strides[2] = 0;
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0); w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1); w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3);
// gradO // gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
// gradI // gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
// gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; // gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // permute gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); // permute
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
gradW_user_md.data.format_desc.blocking.strides[2] = 0; gradW_user_md.data.format_desc.blocking.strides[2] = 0;
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(0); gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i2);
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(1); gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i3);
// gradB // gradB
dnnl::memory::desc gradB_mkl_md; dnnl::memory::desc gradB_mkl_md;
@ -272,10 +294,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights // weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO // gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
@ -332,7 +354,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
@ -347,21 +369,22 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC);
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -394,12 +417,12 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {
PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
@ -416,10 +439,11 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
int dW = INT_ARG(7); // dilations width int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weights->sizeAt(indWmC); // channels multiplier mC = weights->sizeAt(indWmC); // channels multiplier
int trueoH, trueoW; // correct output height, width int trueoH, trueoW; // correct output height, width
@ -428,13 +452,13 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC);
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias) if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
return Status::OK(); return Status::OK();
} }
@ -443,12 +467,12 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) { PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
const DataType xType = input->dataType(); const DataType xType = input->dataType();

View File

@ -272,13 +272,13 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// provide memory and check whether reorder is required // provide memory and check whether reorder is required
// x // x
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, lstm_prim_desc.src_layer_desc(), DNNL_ARG_SRC_LAYER); mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]);
// wx // wx
mkldnnUtils::loadDataToMklStream(Wx, engine, stream, args, wx_user_md, lstm_prim_desc.weights_layer_desc(), DNNL_ARG_WEIGHTS_LAYER); mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]);
// wr // wr
mkldnnUtils::loadDataToMklStream(Wr, engine, stream, args, wr_user_md, lstm_prim_desc.weights_iter_desc(), DNNL_ARG_WEIGHTS_ITER); mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]);
// h // h
auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer()); auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer());
@ -288,17 +288,17 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// b // b
if(b) { if(b) {
mkldnnUtils::loadDataToMklStream(b, engine, stream, args, b_user_md, lstm_prim_desc.bias_desc(), DNNL_ARG_BIAS); mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]);
} }
// hI // hI
if(hI) { if(hI) {
mkldnnUtils::loadDataToMklStream(hI, engine, stream, args, hI_user_md, lstm_prim_desc.src_iter_desc(), DNNL_ARG_SRC_ITER); mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]);
} }
// cI // cI
if(cI) { if(cI) {
mkldnnUtils::loadDataToMklStream(cI, engine, stream, args, cI_user_md, lstm_prim_desc.src_iter_c_desc(), DNNL_ARG_SRC_ITER_C); mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]);
} }
bool hLReorder(false), cLReorder(false); bool hLReorder(false), cLReorder(false);

View File

@ -163,7 +163,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(xTR, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
/* /*
auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer()); auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
@ -173,7 +173,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
args[DNNL_ARG_SRC] = x_mkl_mem; args[DNNL_ARG_SRC] = x_mkl_mem;
*/ */
// y // y
mkldnnUtils::loadDataToMklStream(yTR, engine, stream, args, y_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
/* /*
auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer()); auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer());
const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();

View File

@ -60,7 +60,7 @@ PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
if (paddingMode) if (paddingMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@ -102,7 +102,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

View File

@ -60,7 +60,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
if(paddingMode) // SAME if(paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
@ -107,7 +107,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

View File

@ -56,25 +56,27 @@ dnnl::memory::format_tag getFormat(const int rank){
} }
return dnnl::memory::format_tag::a; // 1 == dataSetRank return dnnl::memory::format_tag::a; // 1 == dataSetRank
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd){ void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd){
if (array->ews() != 1 || array->ordering() != 'c') {
mklMd.data.format_kind = dnnl_blocked; // overrides format if (array->ews() != 1 || array->ordering() != 'c') {
for (auto i = 0; i < rank; ++i) { mklMd.data.format_kind = dnnl_blocked; // overrides format
mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i); for (auto i = 0; i < array->rankOf(); ++i) {
} mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i);
} }
}
} }
//////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////
void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream, void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG ){ dnnl::memory& arg) {
auto user_mem = dnnl::memory(user_md, engine, array->getBuffer()); auto user_mem = dnnl::memory(user_md, engine, array->getBuffer());
const bool bReorder = primitive_md != user_mem.get_desc(); const bool bReorder = primitive_md != user_mem.get_desc();
auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem;
if (bReorder) if (bReorder)
dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem);
args[DNNL_ARG] = mkl_mem; arg = mkl_mem;
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -95,7 +97,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
if(rank == 4) { // 2d if(rank == 4) { // 2d
ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
strides = { sH, sW }; strides = { sH, sW };
kernel = { kH, kW }; kernel = { kH, kW };
@ -108,7 +110,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
} }
else { // 3d else { // 3d
ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
strides = { sD, sH, sW }; strides = { sD, sH, sW };
kernel = { kD, kH, kW }; kernel = { kD, kH, kW };
@ -162,7 +164,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// output // output
auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
@ -199,7 +201,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
if(rank == 4) { // 2d if(rank == 4) { // 2d
ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
strides = { sH, sW }; strides = { sH, sW };
kernel = { kH, kW }; kernel = { kH, kW };
@ -212,7 +214,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
} }
else { // 3d else { // 3d
ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
strides = { sD, sH, sW }; strides = { sD, sH, sW };
kernel = { kD, kH, kW }; kernel = { kD, kH, kW };
@ -280,7 +282,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
std::unordered_map<int, dnnl::memory> args; std::unordered_map<int, dnnl::memory> args;
// gradO // gradO
mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
// gradI // gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
@ -291,7 +293,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
if(mode == algorithm::pooling_max) { if(mode == algorithm::pooling_max) {
// input // input
mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// z // z
auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);

View File

@ -131,7 +131,7 @@ namespace sd {
* @param reference to memory descriptor * @param reference to memory descriptor
* @return memory format * @return memory format
*/ */
void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd); void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd);
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
/** /**
* This function load and reorder user memory to mkl * This function load and reorder user memory to mkl
@ -143,8 +143,8 @@ namespace sd {
* @param primitive memory descriptor * @param primitive memory descriptor
* @param dnnl arg activation enumerator * @param dnnl arg activation enumerator
*/ */
void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream, void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG); dnnl::memory& arg);
/** /**
* Utility methods for MKLDNN * Utility methods for MKLDNN

View File

@ -55,12 +55,12 @@ namespace sd {
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
mkldnnUtils::setBlockStrides(x, xRank, x_user_md); mkldnnUtils::setBlockStrides(x, x_user_md);
// z // z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format);
mkldnnUtils::setBlockStrides(z, xRank, z_user_md); mkldnnUtils::setBlockStrides(z, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -80,7 +80,7 @@ namespace sd {
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// z // z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
@ -156,19 +156,19 @@ namespace sd {
// x // x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(x, xRank, x_user_md); mkldnnUtils::setBlockStrides(x, x_user_md);
// dLdx // dLdx
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md); mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
// todo if mkl does not support broadcast we can remove this // todo if mkl does not support broadcast we can remove this
format = mkldnnUtils::getFormat(dLdzRank); format = mkldnnUtils::getFormat(dLdzRank);
// dLdz // dLdz
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdz, dLdzRank, dLdz_user_md); mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -188,7 +188,7 @@ namespace sd {
// provide memory buffers and check whether reorder is required for forward // provide memory buffers and check whether reorder is required for forward
// input // input
mkldnnUtils::loadDataToMklStream(x, engine, stream, argsff, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]);
// dLdx // dLdx
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());
@ -200,7 +200,7 @@ namespace sd {
argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
argsbp[DNNL_ARG_DST] = dLdx_mkl_mem; argsbp[DNNL_ARG_DST] = dLdx_mkl_mem;
// dLdz // dLdz
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, argsbp, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]);
// run calculations forward // run calculations forward
dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff);

View File

@ -44,12 +44,12 @@ namespace sd {
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(x, xRank, x_user_md); mkldnnUtils::setBlockStrides(x, x_user_md);
// z // z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(z, xRank, z_user_md); mkldnnUtils::setBlockStrides(z, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -68,7 +68,7 @@ namespace sd {
// provide memory buffers and check whether reorder is required // provide memory buffers and check whether reorder is required
// input // input
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// z // z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
@ -132,17 +132,17 @@ namespace sd {
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(x, xRank, x_user_md); mkldnnUtils::setBlockStrides(x, x_user_md);
// dLdz // dLdz
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdz, xRank, dLdz_user_md); mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
// dLdx // dLdx
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md); mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -162,10 +162,10 @@ namespace sd {
// provide memory buffers and check whether reorder is required for forward // provide memory buffers and check whether reorder is required for forward
// input // input
mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// dLdz // dLdz
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, args, dLdz_user_md, op_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
// dLdx // dLdx
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long