Shyrma weights format (#329)
* - start to introduce additional weights formats into conv2d ops Signed-off-by: Yurii <iuriish@yahoo.com> * - provide weights format variety in backprop conv2d and deconv2d ops, testing and fixing bugs Signed-off-by: Yurii <iuriish@yahoo.com> * - forgot to recover kernels sizes in deconv2d_bp test Signed-off-by: Yurii <iuriish@yahoo.com> * - built in weights format in depthwise conv 2d op Signed-off-by: Yurii <iuriish@yahoo.com> * - provide new weights formats in mkl dnn conv ops Signed-off-by: Yurii <iuriish@yahoo.com> * - provide new weights formats in cuda conv helpers Signed-off-by: Yurii <iuriish@yahoo.com> * - working with new weights format in cudnn conv api Signed-off-by: Yurii <iuriish@yahoo.com> * - take into account order of arrays in cudnn tensor descriptions Signed-off-by: Yurii <iuriish@yahoo.com> * - provide new weights formats in cpu conv3d (ff/bp) Signed-off-by: Yurii <iuriish@yahoo.com> * - provide new weights formats in cpu deconv3d (ff/bp) Signed-off-by: Yurii <iuriish@yahoo.com> * - provide new weights formats in conv3d ops (ff/bp) based on mkl api Signed-off-by: Yurii <iuriish@yahoo.com> * - provide new weights formats in conv3d ops (ff/bp) based on cudnn api Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts 2 Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									5dae4069cf
								
							
						
					
					
						commit
						e700b59f80
					
				| @ -4076,7 +4076,7 @@ INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShap | |||||||
| 
 | 
 | ||||||
|     // *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code
 |     // *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code
 | ||||||
| 
 | 
 | ||||||
|     // FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities)
 |     // FIXME - indeed we don't need to allocate so large memory amount (4*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities)
 | ||||||
|     Nd4jLong tempBuffer[4*MAX_RANK]; |     Nd4jLong tempBuffer[4*MAX_RANK]; | ||||||
|     Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides,  *newStrides; |     Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides,  *newStrides; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -34,7 +34,7 @@ namespace ops  { | |||||||
| CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { | CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_NULLIFIED(0);                                   // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW)
 |     auto output  = OUTPUT_NULLIFIED(0);                                   // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW)
 | ||||||
| @ -45,12 +45,13 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { | |||||||
|     int dW = INT_ARG(3);                                                        // dilations width
 |     int dW = INT_ARG(3);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(4);                                               // 0-VALID, 1-SAME, 2-CAUSAL
 |     int paddingMode = INT_ARG(4);                                               // 0-VALID, 1-SAME, 2-CAUSAL
 | ||||||
|     int isNCW       = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1;      // INT_ARG(4): 0-NCW,  1-NWC
 |     int isNCW       = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1;      // INT_ARG(4): 0-NCW,  1-NWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0;           // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 3; |     const int rank = 3; | ||||||
|     REQUIRE_TRUE(input->rankOf()   == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); | ||||||
|     REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); |     REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiW, indWoC(2); |     int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); | ||||||
|     if(!isNCW) { |     if(!isNCW) { | ||||||
|         indIOioC = 2; indIiW = 1; |         indIOioC = 2; indIiW = 1; | ||||||
|     } |     } | ||||||
| @ -63,7 +64,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { | |||||||
|     int iC = input->sizeAt(indIOioC);                 // input channels
 |     int iC = input->sizeAt(indIOioC);                 // input channels
 | ||||||
|     int oC = weights->sizeAt(indWoC);                 // output channels
 |     int oC = weights->sizeAt(indWoC);                 // output channels
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC})); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| @ -83,11 +84,11 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { | |||||||
|     auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});   // [kW, iC, oC] -> [1, kW, iC, oC]
 |     auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});   // [kW, iC, oC] -> [1, kW, iC, oC]
 | ||||||
| 
 | 
 | ||||||
|     sd::ops::conv2d conv2d; |     sd::ops::conv2d conv2d; | ||||||
|     const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW,  1,sW,  0,pW,  1,dW,  paddingMode,  !isNCW}, {}); |     const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW,  1,sW,  0,pW,  1,dW,  paddingMode, !isNCW, wFormat}, {}); | ||||||
|     if (status != ND4J_STATUS_OK) |     if (status != ND4J_STATUS_OK) | ||||||
|         return status; |         return status; | ||||||
| 
 | 
 | ||||||
|     // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW,  1,sW,  0,pW,  1,dW,  paddingMode,  isNCW);
 |     // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW,  1,sW,  0,pW,  1,dW,  paddingMode, isNCW, wFormat);
 | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -105,8 +106,9 @@ DECLARE_SHAPE_FN(conv1d) { | |||||||
|     int dW = INT_ARG(3);                                                        // dilations width
 |     int dW = INT_ARG(3);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(4);                                               // 0-VALID, 1-SAME
 |     int paddingMode = INT_ARG(4);                                               // 0-VALID, 1-SAME
 | ||||||
|     int isNCW  = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1;           // INT_ARG(4): 1-NWC, 0-NCW
 |     int isNCW  = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1;           // INT_ARG(4): 1-NWC, 0-NCW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0;           // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiW, indWoC(2); |     int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); | ||||||
|     if(!isNCW) { |     if(!isNCW) { | ||||||
|         indIOioC = 2; indIiW = 1; |         indIOioC = 2; indIiW = 1; | ||||||
|     } |     } | ||||||
| @ -123,7 +125,7 @@ DECLARE_SHAPE_FN(conv1d) { | |||||||
|     int iC = inputShapeInfo[indIOioC+1];                   // input channels
 |     int iC = inputShapeInfo[indIOioC+1];                   // input channels
 | ||||||
|     int oC = weightsShapeInfo[indWoC+1];                 // output channels
 |     int oC = weightsShapeInfo[indWoC+1];                 // output channels
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC})); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if (biasShapeInfo) |     if (biasShapeInfo) | ||||||
|         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); |         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); | ||||||
| @ -163,12 +165,12 @@ DECLARE_TYPES(conv1d) { | |||||||
| CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { | CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon
 |     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon
 | ||||||
|     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kW, iC, oC] always
 |     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
 |     int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
 | ||||||
| @ -177,12 +179,14 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { | |||||||
|     int dW = INT_ARG(3);                                                        // dilations width
 |     int dW = INT_ARG(3);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(4);                                               // 0-VALID, 1-SAME, 2-CAUSAL
 |     int paddingMode = INT_ARG(4);                                               // 0-VALID, 1-SAME, 2-CAUSAL
 | ||||||
|     int isNCW  = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1;           // INT_ARG(4): 1-NWC, 0-NCW
 |     int isNCW  = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1;           // INT_ARG(4): 1-NWC, 0-NCW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0;           // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 3; |     const int rank = 3; | ||||||
|     REQUIRE_TRUE(input->rankOf()   == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); | ||||||
|     REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); |     REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); | ||||||
|     REQUIRE_TRUE(gradO->rankOf()   == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradO->rankOf()); |     REQUIRE_TRUE(gradO->rankOf()   == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradO->rankOf()); | ||||||
|     int indIOioC, indIiW, indWoC(2); | 
 | ||||||
|  |     int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); | ||||||
|     if(!isNCW) { |     if(!isNCW) { | ||||||
|         indIOioC = 2; indIiW = 1; |         indIOioC = 2; indIiW = 1; | ||||||
|     } |     } | ||||||
| @ -199,7 +203,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { | |||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW,  0,indIOioC,indIiW}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW,  0,indIOioC,indIiW}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC})); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
| @ -222,11 +226,11 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { | |||||||
|     auto gradWReshaped   = gradW  ->reshape(gradW->ordering(),  {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
 |     auto gradWReshaped   = gradW  ->reshape(gradW->ordering(),  {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
 | ||||||
| 
 | 
 | ||||||
|     sd::ops::conv2d_bp conv2dBP; |     sd::ops::conv2d_bp conv2dBP; | ||||||
|     auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW,  1,sW,  0,pW,  1,dW,  paddingMode,  !isNCW}, {}); |     auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW,  1,sW,  0,pW,  1,dW,  paddingMode, !isNCW, wFormat}, {}); | ||||||
|     if (status != ND4J_STATUS_OK) |     if (status != ND4J_STATUS_OK) | ||||||
|         return status; |         return status; | ||||||
| 
 | 
 | ||||||
|     // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW,  1,sW,  0,pW,  1,dW,  paddingMode,  isNCW);
 |     // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW,  1,sW,  0,pW,  1,dW,  paddingMode, isNCW, wFormat);
 | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -235,7 +239,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { | |||||||
| DECLARE_SHAPE_FN(conv1d_bp) { | DECLARE_SHAPE_FN(conv1d_bp) { | ||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo   = inputShape->at(0);                                               // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
 |     auto inputShapeInfo   = inputShape->at(0);                                               // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
 | ||||||
|     auto weightsShapeInfo = inputShape->at(1);                                               // [kW, iC, oC] always
 |     auto weightsShapeInfo = inputShape->at(1);                                               // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC]
 | ||||||
|     Nd4jLong* biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;            // [oC]
 |     Nd4jLong* biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;            // [oC]
 | ||||||
|     Nd4jLong* gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);  // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next
 |     Nd4jLong* gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);  // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
| @ -250,8 +254,9 @@ DECLARE_SHAPE_FN(conv1d_bp) { | |||||||
|     int dW = INT_ARG(3);                                                        // dilations width
 |     int dW = INT_ARG(3);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(4);                                               // 0-VALID, 1-SAME
 |     int paddingMode = INT_ARG(4);                                               // 0-VALID, 1-SAME
 | ||||||
|     int isNCW  = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1;           // INT_ARG(4): 1-NWC, 0-NCW
 |     int isNCW  = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1;           // INT_ARG(4): 1-NWC, 0-NCW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0;           // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiW, indWoC(2); |     int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); | ||||||
|     if(!isNCW) { |     if(!isNCW) { | ||||||
|         indIOioC = 2; indIiW = 1; |         indIOioC = 2; indIiW = 1; | ||||||
|     } |     } | ||||||
| @ -268,7 +273,7 @@ DECLARE_SHAPE_FN(conv1d_bp) { | |||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW,  0,indIOioC,indIiW}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW,  0,indIOioC,indIiW}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = 0 == wFormat ? std::vector<Nd4jLong>({kW, iC, oC}) : (1 == wFormat ? std::vector<Nd4jLong>({oC, iC, kW}) : std::vector<Nd4jLong>({oC, kW, iC})); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0,  "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0,  "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if(biasShapeInfo) |     if(biasShapeInfo) | ||||||
|  | |||||||
| @ -37,7 +37,7 @@ namespace ops  { | |||||||
| CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { | CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_NULLIFIED(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 |     auto output  = OUTPUT_NULLIFIED(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 | ||||||
| @ -49,21 +49,22 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { | |||||||
|     int dH = INT_ARG(6);                                                        // dilations height
 |     int dH = INT_ARG(6);                                                        // dilations height
 | ||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     bool isNCHW    = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     int isNCHW  = block.getIArguments()->size() > 9  ? !INT_ARG(9) : 1;         // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
 |     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
 | ||||||
|     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
 |     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); |     ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -73,7 +74,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { | |||||||
| DECLARE_SHAPE_FN(conv2d) { | DECLARE_SHAPE_FN(conv2d) { | ||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo   = inputShape->at(0);                                  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto inputShapeInfo   = inputShape->at(0);                                  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weightsShapeInfo = inputShape->at(1);                                  // [kH, kW, iC, oC] always
 |     auto weightsShapeInfo = inputShape->at(1);                                  // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;    // [oC]
 |     auto biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;    // [oC]
 | ||||||
| 
 | 
 | ||||||
|     //output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 |     //output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 | ||||||
| @ -86,6 +87,7 @@ DECLARE_SHAPE_FN(conv2d) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) height
 |     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) height
 | ||||||
|     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1)); // filter(kernel) width
 |     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1)); // filter(kernel) width
 | ||||||
| @ -95,7 +97,7 @@ DECLARE_SHAPE_FN(conv2d) { | |||||||
|     REQUIRE_TRUE(inputShapeInfo[0]   == rank, 0, "CUSTOM CONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); |     REQUIRE_TRUE(inputShapeInfo[0]   == rank, 0, "CUSTOM CONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); | ||||||
|     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); |     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indWoC(3); |     int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0); | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; |         indIOioC = 3; indIiH = 1; | ||||||
|     } |     } | ||||||
| @ -109,7 +111,7 @@ DECLARE_SHAPE_FN(conv2d) { | |||||||
|     const int iC = inputShapeInfo[indIOioC+1];                   // input channels
 |     const int iC = inputShapeInfo[indIOioC+1];                   // input channels
 | ||||||
|     const int oC = weightsShapeInfo[indWoC+1];                   // output channels
 |     const int oC = weightsShapeInfo[indWoC+1];                   // output channels
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if (biasShapeInfo) |     if (biasShapeInfo) | ||||||
|         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); |         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); | ||||||
| @ -157,12 +159,12 @@ DECLARE_SHAPE_FN(conv2d) { | |||||||
| CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { | CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
|     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kH, kW, iC, oC] always
 |     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0);                                                        // filter(kernel) height
 |     int kH = INT_ARG(0);                                                        // filter(kernel) height
 | ||||||
| @ -175,6 +177,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); | ||||||
|     REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); |     REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); | ||||||
| @ -182,19 +185,19 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // true output height, width
 |     int trueoH, trueoW;          // true output height, width
 | ||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong>expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong>expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong>expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); |     ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -204,7 +207,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { | |||||||
| DECLARE_SHAPE_FN(conv2d_bp) { | DECLARE_SHAPE_FN(conv2d_bp) { | ||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo   = inputShape->at(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto inputShapeInfo   = inputShape->at(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weightsShapeInfo = inputShape->at(1);                                                // [kH, kW, iC, oC] always
 |     auto weightsShapeInfo = inputShape->at(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;                  // [oC]
 |     auto biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
| @ -224,8 +227,9 @@ DECLARE_SHAPE_FN(conv2d_bp) { | |||||||
|     const int dW = INT_ARG(7);                                                        // dilations width
 |     const int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     const int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     const int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     const int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 |     const int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 | ||||||
|  |     const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indOoH, indWoC(3); |     int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 3 : 0); | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; indOoH = 1; |         indIOioC = 3; indIiH = 1; indOoH = 1; | ||||||
|     } |     } | ||||||
| @ -243,7 +247,7 @@ DECLARE_SHAPE_FN(conv2d_bp) { | |||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0,  "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0,  "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if(biasShapeInfo) |     if(biasShapeInfo) | ||||||
| @ -264,7 +268,7 @@ DECLARE_SHAPE_FN(conv2d_bp) { | |||||||
| CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { | CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { | ||||||
| 
 | 
 | ||||||
|     auto gradIShape = INPUT_VARIABLE(0);                                                // [4]
 |     auto gradIShape = INPUT_VARIABLE(0);                                                // [4]
 | ||||||
|     auto weights    = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC] always
 |     auto weights    = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto gradO      = INPUT_VARIABLE(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO      = INPUT_VARIABLE(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
| @ -279,6 +283,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = gradO->rankOf(); |     const int rank = gradO->rankOf(); | ||||||
| 
 | 
 | ||||||
| @ -295,17 +300,17 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // true output height, width
 |     int trueoH, trueoW;          // true output height, width
 | ||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); |     ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -321,7 +326,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { | |||||||
| DECLARE_SHAPE_FN(conv2d_input_bp) { | DECLARE_SHAPE_FN(conv2d_input_bp) { | ||||||
| 
 | 
 | ||||||
|     auto gradIShapeShapeInfo = inputShape->at(0);                                                // [4]
 |     auto gradIShapeShapeInfo = inputShape->at(0);                                                // [4]
 | ||||||
|     auto weightsShapeInfo    = inputShape->at(1);                                                // [kH, kW, iC, oC] always
 |     auto weightsShapeInfo    = inputShape->at(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto gradOShapeInfo      = inputShape->at(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradOShapeInfo      = inputShape->at(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 4; |     const int rank = 4; | ||||||
| @ -340,8 +345,9 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { | |||||||
|     const int dW = INT_ARG(7);                                                        // dilations width
 |     const int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     const int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     const int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     const int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 |     const int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 | ||||||
|  |     const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indWoC(3), indOoH; |     int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; indOoH = 1; |         indIOioC = 3; indIiH = 1; indOoH = 1; | ||||||
|     } |     } | ||||||
| @ -361,7 +367,7 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { | |||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0,  "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0,  "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ namespace ops  { | |||||||
| 
 | 
 | ||||||
| CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { | CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 | ||||||
| 
 | 
 | ||||||
| @ -52,14 +52,15 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { | |||||||
|     int dH = INT_ARG(10);                                                       // dilations height
 |     int dH = INT_ARG(10);                                                       // dilations height
 | ||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int paddingMode = INT_ARG(12);                                              // 0-VALID, 1-SAME
 |     int paddingMode = INT_ARG(12);                                              // 0-VALID, 1-SAME
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;        // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); |     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| @ -71,14 +72,24 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { | |||||||
|     std::vector<int> permutForOutput; |     std::vector<int> permutForOutput; | ||||||
| 
 | 
 | ||||||
|     if (isNCDHW) |     if (isNCDHW) | ||||||
|         permutForOutput    = {0,2,3,4,1};                                        // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
 |         permutForOutput = {0,2,3,4,1};                                        // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
 | ||||||
|     else |     else | ||||||
|         input = new NDArray(input->permute({0,4,1,2,3})); |         input = new NDArray(input->permute({0,4,1,2,3})); | ||||||
| 
 | 
 | ||||||
|  |     std::vector<int> wAxes; | ||||||
|  |     if(0 == wFormat) | ||||||
|  |         wAxes = {3,0,1,2}; | ||||||
|  |     else if(1 == wFormat) | ||||||
|  |         wAxes = {1,2,3,4}; | ||||||
|  |     else | ||||||
|  |         wAxes = {4,1,2,3}; | ||||||
|  | 
 | ||||||
|     NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); |     NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); | ||||||
|     ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);                 // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
 |     ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);                 // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
 | ||||||
|     // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
 |     // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
 | ||||||
|     MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput); |     // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, iC, kD, kH, kW] = [bS, oD, oH, oW, oC]
 | ||||||
|  |     // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, kD, kH, kW, iC] = [bS, oD, oH, oW, oC]
 | ||||||
|  |     MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, wAxes, permutForOutput); | ||||||
| 
 | 
 | ||||||
|     if(bias) |     if(bias) | ||||||
|         // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
 |         // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
 | ||||||
| @ -101,7 +112,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { | |||||||
| DECLARE_SHAPE_FN(conv3dnew) { | DECLARE_SHAPE_FN(conv3dnew) { | ||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo   = inputShape->at(0);                                  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto inputShapeInfo   = inputShape->at(0);                                  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weightsShapeInfo = inputShape->at(1);                                  // [kD, kH, kW, iC, oC] always
 |     auto weightsShapeInfo = inputShape->at(1);                                  // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     auto biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;    // [oC]
 |     auto biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;    // [oC]
 | ||||||
| 
 | 
 | ||||||
|     int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
 |     int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
 | ||||||
| @ -118,13 +129,14 @@ DECLARE_SHAPE_FN(conv3dnew) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID;
 |     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID;
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 5; |     const int rank = 5; | ||||||
|     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); |     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); | ||||||
|     REQUIRE_TRUE(inputShapeInfo[0]   == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); |     REQUIRE_TRUE(inputShapeInfo[0]   == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); | ||||||
|     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); |     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiD, indWoC(4); |     int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); | ||||||
|     if(!isNCDHW) { |     if(!isNCDHW) { | ||||||
|         indIOioC = 4; indIiD = 1; |         indIOioC = 4; indIiD = 1; | ||||||
|     } |     } | ||||||
| @ -139,7 +151,7 @@ DECLARE_SHAPE_FN(conv3dnew) { | |||||||
|     int iC = inputShapeInfo[indIOioC+1];                  // input channels
 |     int iC = inputShapeInfo[indIOioC+1];                  // input channels
 | ||||||
|     int oC = weightsShapeInfo[indWoC+1];                  // output channels
 |     int oC = weightsShapeInfo[indWoC+1];                  // output channels
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if (biasShapeInfo) |     if (biasShapeInfo) | ||||||
|         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); |         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); | ||||||
| @ -174,12 +186,12 @@ DECLARE_SHAPE_FN(conv3dnew) { | |||||||
| CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { | CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
 |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
 | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, iC, oC] always
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 5, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 5, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); | ||||||
| @ -200,17 +212,18 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID
 |     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     int trueoD, trueoH, trueoW;          // true output depth/height/width
 |     int trueoD, trueoH, trueoW;          // true output depth/height/width
 | ||||||
|     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); |     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); |     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
| @ -231,10 +244,25 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { | |||||||
|         gradOaxesForDot  = {0,2,3,4};                                           // bS, oD, oH, oW
 |         gradOaxesForDot  = {0,2,3,4};                                           // bS, oD, oH, oW
 | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     std::vector<int> wPermut, colPermut; | ||||||
|  | 
 | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         wPermut   = {3,0,1,2,4}; | ||||||
|  |         colPermut = {2,3,4,1,0,5,6,7}; | ||||||
|  |     } | ||||||
|  |     else if(1 == wFormat) { | ||||||
|  |         wPermut   = {1,2,3,4,0}; | ||||||
|  |         colPermut = {1,2,3,4,0,5,6,7}; | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         wPermut   = {4,1,2,3,0}; | ||||||
|  |         colPermut = {2,3,4,1,0,5,6,7}; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // ----- calculation of gradW and gradB ----- //
 |     // ----- calculation of gradW and gradB ----- //
 | ||||||
|     NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); |     NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); | ||||||
|     ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);                   // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
 |     ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);                   // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
 | ||||||
|     MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4});     // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
 |     MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, wPermut);     // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     //----- calculation of gradO -----//
 |     //----- calculation of gradO -----//
 | ||||||
|     if(gradB) { |     if(gradB) { | ||||||
| @ -246,7 +274,10 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     //----- calculation of gradI -----//
 |     //----- calculation of gradI -----//
 | ||||||
|     MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7});   // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
 |     // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
 | ||||||
|  |     // [oC, iC, kD, kH, kW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
 | ||||||
|  |     // [oC, kD, kH, kW, iC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
 | ||||||
|  |     MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); | ||||||
|     ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW);                   // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to  [bS, iC, iD, iH, iW]
 |     ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW);                   // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to  [bS, iC, iD, iH, iW]
 | ||||||
| 
 | 
 | ||||||
|     if(!isNCDHW) { |     if(!isNCDHW) { | ||||||
| @ -270,7 +301,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { | |||||||
| DECLARE_SHAPE_FN(conv3dnew_bp) { | DECLARE_SHAPE_FN(conv3dnew_bp) { | ||||||
| 
 | 
 | ||||||
|     Nd4jLong* inputShapeInfo   = inputShape->at(0);                                              // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     Nd4jLong* inputShapeInfo   = inputShape->at(0);                                              // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     Nd4jLong* weightsShapeInfo = inputShape->at(1);                                              // [kD, kH, kW, iC, oC] always
 |     Nd4jLong* weightsShapeInfo = inputShape->at(1);                                              // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     Nd4jLong* biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;                // [oC]
 |     Nd4jLong* biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;                // [oC]
 | ||||||
|     Nd4jLong* gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);      // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 |     Nd4jLong* gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);      // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
| @ -288,6 +319,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int paddingMode = INT_ARG(12);                                               // 1-SAME,  0-VALID
 |     int paddingMode = INT_ARG(12);                                               // 1-SAME,  0-VALID
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 5; |     const int rank = 5; | ||||||
|     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); |     REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); | ||||||
| @ -295,7 +327,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { | |||||||
|     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); |     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); | ||||||
|     REQUIRE_TRUE(gradOShapeInfo[0]   == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo); |     REQUIRE_TRUE(gradOShapeInfo[0]   == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo); | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiD, indWoC(4); |     int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); | ||||||
|     if(!isNCDHW) { |     if(!isNCDHW) { | ||||||
|         indIOioC = 4; indIiD = 1; |         indIOioC = 4; indIiD = 1; | ||||||
|     } |     } | ||||||
| @ -314,7 +346,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { | |||||||
|     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); |     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIiD,indIiD+1,indIiD+2}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIiD,indIiD+1,indIiD+2}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape),   0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape),   0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if(biasShapeInfo) |     if(biasShapeInfo) | ||||||
|  | |||||||
| @ -35,7 +35,7 @@ namespace ops  { | |||||||
| CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { | CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_NULLIFIED(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 |     auto output  = OUTPUT_NULLIFIED(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 | ||||||
| @ -53,12 +53,13 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| @ -66,6 +67,12 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { | |||||||
|     if(!isNCHW) |     if(!isNCHW) | ||||||
|         output = new NDArray(output->permute({0, 3, 1, 2}));       // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
 |         output = new NDArray(output->permute({0, 3, 1, 2}));       // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
 | ||||||
| 
 | 
 | ||||||
|  |     std::vector<int> colPermut; | ||||||
|  |     if(1 == wFormat) | ||||||
|  |         colPermut = {1, 2, 3, 0, 4, 5}; | ||||||
|  |     else | ||||||
|  |         colPermut = {2, 3, 1, 0, 4, 5}; | ||||||
|  | 
 | ||||||
|     if(isSameMode)          // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
 |     if(isSameMode)          // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
 | ||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); |         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); | ||||||
| 
 | 
 | ||||||
| @ -73,8 +80,9 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { | |||||||
| 
 | 
 | ||||||
|     //----- calculation of output -----//
 |     //----- calculation of output -----//
 | ||||||
|     // NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
 |     // NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
 | ||||||
|     // NCHW: [kH, kW, oC, iC] x [bS, iC, iH, iW] = [kH, kW, oC, bS, iH, iW]
 |     // NHWC: [iC, oC, kH, kW] x [bS, iH, iW, iC] = [oC, kH, kW, bS, iH, iW]
 | ||||||
|     sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); |     // NHWC: [iC, kH, kW, oC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
 | ||||||
|  |     sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut); | ||||||
|     LaunchContext* ctx = block.launchContext(); |     LaunchContext* ctx = block.launchContext(); | ||||||
|     helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW);     // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW]
 |     helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW);     // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW]
 | ||||||
| 
 | 
 | ||||||
| @ -97,7 +105,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { | |||||||
| DECLARE_SHAPE_FN(deconv2d) { | DECLARE_SHAPE_FN(deconv2d) { | ||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo   = inputShape->at(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto inputShapeInfo   = inputShape->at(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weightsShapeInfo = inputShape->at(1);                                    // [kH, kW, oC, iC] always
 |     auto weightsShapeInfo = inputShape->at(1);                                    // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     auto biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;      // [oC]
 |     auto biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 4; |     const int rank = 4; | ||||||
| @ -114,8 +122,9 @@ DECLARE_SHAPE_FN(deconv2d) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indWoC(2); |     int indIOioC, indIiH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; |         indIOioC = 3; indIiH = 1; | ||||||
|     } |     } | ||||||
| @ -129,7 +138,7 @@ DECLARE_SHAPE_FN(deconv2d) { | |||||||
|     const int iC = inputShapeInfo[indIOioC+1];                   // input channels
 |     const int iC = inputShapeInfo[indIOioC+1];                   // input channels
 | ||||||
|     const int oC = weightsShapeInfo[indWoC+1];                   // output channels
 |     const int oC = weightsShapeInfo[indWoC+1];                   // output channels
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if (biasShapeInfo) |     if (biasShapeInfo) | ||||||
|         REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); |         REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); | ||||||
| @ -163,12 +172,12 @@ DECLARE_SHAPE_FN(deconv2d) { | |||||||
| CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { | CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), gradI
 |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), gradI
 | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, oC, iC] always
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); | ||||||
| @ -186,16 +195,17 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // true output height, width
 |     int trueoH, trueoW;          // true output height, width
 | ||||||
|     ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
| @ -206,29 +216,34 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { | |||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); |         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 |     // ----- calculation of gradI -> pass it through conv2d_ff ----- //
 | ||||||
|      // ----- calculation of gradI -> pass it through conv2d_ff ----- //
 |  | ||||||
|     sd::ops::conv2d conv2d; |     sd::ops::conv2d conv2d; | ||||||
|     const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW,  isSameMode,  !isNCHW}, {}); |     const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW,  isSameMode, !isNCHW, wFormat}, {}); | ||||||
|     if (status != ND4J_STATUS_OK) |     if (status != ND4J_STATUS_OK) | ||||||
|         return status; |         return status; | ||||||
| 
 | 
 | ||||||
|     // -----prepare permutation arrays and axes for dot product ----- //
 |     // -----prepare permutation arrays and axes for dot product ----- //
 | ||||||
|     std::vector<int> inputAxesForDot; |     std::vector<int> inputAxes; | ||||||
| 
 | 
 | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         gradO = new NDArray(gradO->permute({0, 3, 1, 2}));                      // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
 |         gradO = new NDArray(gradO->permute({0, 3, 1, 2}));                      // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
 | ||||||
|         inputAxesForDot = {0, 1, 2};                                            // bS, iH, iW
 |         inputAxes = {0, 1, 2};                                            // bS, iH, iW
 | ||||||
|     } |     } | ||||||
|     else |     else | ||||||
|         inputAxesForDot = {0, 2, 3};                                            // bS, iH, iW
 |         inputAxes = {0, 2, 3};                                            // bS, iH, iW
 | ||||||
|  | 
 | ||||||
|  |     std::vector<int> gradWAxes;     // empty for wFormat = 1
 | ||||||
|  |     if(0 == wFormat) | ||||||
|  |         gradWAxes = {3, 2, 0, 1}; | ||||||
|  |     else if(2 == wFormat) | ||||||
|  |         gradWAxes = {0, 3, 1, 2}; | ||||||
| 
 | 
 | ||||||
|     // ----- calculation of gradW ----- //
 |     // ----- calculation of gradW ----- //
 | ||||||
|     NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); |     NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); | ||||||
| 
 | 
 | ||||||
|     LaunchContext* ctx = block.launchContext(); |     LaunchContext* ctx = block.launchContext(); | ||||||
|     helpers::im2col(*ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, oC, oH, oW] is convoluted to [bS, oC, kH, kW, iH, iW]
 |     helpers::im2col(*ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, oC, oH, oW] is convoluted to [bS, oC, kH, kW, iH, iW]
 | ||||||
|     MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 4, 5}, {3, 2, 0, 1});     // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW]
 |     MmulHelper::tensorDot(input, &columns, gradW, inputAxes, {0, 4, 5}, gradWAxes);     // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW]
 | ||||||
| 
 | 
 | ||||||
|     // ----- calculation of gradB ----- //
 |     // ----- calculation of gradB ----- //
 | ||||||
|     if(gradB) { |     if(gradB) { | ||||||
| @ -248,7 +263,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { | |||||||
| DECLARE_SHAPE_FN(deconv2d_bp) { | DECLARE_SHAPE_FN(deconv2d_bp) { | ||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo   = inputShape->at(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto inputShapeInfo   = inputShape->at(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weightsShapeInfo = inputShape->at(1);                                                // [kH, kW, oC, iC] always
 |     auto weightsShapeInfo = inputShape->at(1);                                                // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     Nd4jLong* biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;             // [oC]
 |     Nd4jLong* biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;             // [oC]
 | ||||||
|     Nd4jLong* gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     Nd4jLong* gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
| @ -267,8 +282,9 @@ DECLARE_SHAPE_FN(deconv2d_bp) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indWoC(2), indOoH; |     int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; indOoH = 1; |         indIOioC = 3; indIiH = 1; indOoH = 1; | ||||||
|     } |     } | ||||||
| @ -286,7 +302,7 @@ DECLARE_SHAPE_FN(deconv2d_bp) { | |||||||
|     ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0,  "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0,  "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if(biasShapeInfo) |     if(biasShapeInfo) | ||||||
|  | |||||||
| @ -32,10 +32,10 @@ namespace ops  { | |||||||
| CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { | CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { | ||||||
| 
 | 
 | ||||||
|     auto gradO      = INPUT_VARIABLE(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO      = INPUT_VARIABLE(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
|     auto weights    = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC] always
 |     auto weights    = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto gradIShape = INPUT_VARIABLE(0);                                                // [4] - shape of input of conv2d (that is shape of gradI)
 |     auto gradIShape = INPUT_VARIABLE(0);                                                // [4] - shape of input of conv2d (that is shape of gradI)
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_NULLIFIED(0);                                                  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     auto gradI = OUTPUT_NULLIFIED(0);                                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
 |     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
 | ||||||
|     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
 |     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
 | ||||||
| @ -47,6 +47,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = gradO->rankOf(); |     const int rank = gradO->rankOf(); | ||||||
| 
 | 
 | ||||||
| @ -57,20 +58,19 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { | |||||||
|     // create empty conv2d input array
 |     // create empty conv2d input array
 | ||||||
|     NDArray input(gradO->ordering(), gradIShape->asVectorT<Nd4jLong>(), gradO->dataType(), block.launchContext()); |     NDArray input(gradO->ordering(), gradIShape->asVectorT<Nd4jLong>(), gradO->dataType(), block.launchContext()); | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // true output height, width
 |     int trueoH, trueoW;          // true output height, width
 | ||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); |     ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -84,7 +84,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { | |||||||
| DECLARE_SHAPE_FN(deconv2d_tf) { | DECLARE_SHAPE_FN(deconv2d_tf) { | ||||||
| 
 | 
 | ||||||
|     auto gradOShapeInfo   = inputShape->at(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradOShapeInfo   = inputShape->at(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
|     auto weightsShapeInfo = inputShape->at(1);                                                // [kH, kW, iC, oC] always
 |     auto weightsShapeInfo = inputShape->at(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto gradIShapeShapeInfo = inputShape->at(0);                                             // [4]
 |     auto gradIShapeShapeInfo = inputShape->at(0);                                             // [4]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 4; |     const int rank = 4; | ||||||
| @ -103,8 +103,9 @@ DECLARE_SHAPE_FN(deconv2d_tf) { | |||||||
|     const int dW = INT_ARG(7);                                                        // dilations width
 |     const int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     const int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     const int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     const int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     const int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indWoC(3), indOoH; |     int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; indOoH = 1; |         indIOioC = 3; indIiH = 1; indOoH = 1; | ||||||
|     } |     } | ||||||
| @ -126,7 +127,7 @@ DECLARE_SHAPE_FN(deconv2d_tf) { | |||||||
|     ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode); |     ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(expectedGradIShape == gradIShape, 0,  "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str()); |     REQUIRE_TRUE(expectedGradIShape == gradIShape, 0,  "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str()); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ namespace ops  { | |||||||
| CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { | CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 | ||||||
| @ -53,13 +53,14 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { | |||||||
|     int dH = INT_ARG(10);                                                           // dilations height
 |     int dH = INT_ARG(10);                                                           // dilations height
 | ||||||
|     int dW = INT_ARG(11);                                                           // dilations width
 |     int dW = INT_ARG(11);                                                           // dilations width
 | ||||||
|     int isSameMode = INT_ARG(12);                                                   // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(12);                                                   // 0-VALID, 1-SAME
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;           // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;            // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;             // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| @ -67,16 +68,23 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { | |||||||
|     if(!isNCDHW) |     if(!isNCDHW) | ||||||
|         output = new NDArray(output->permute({0, 4, 1, 2, 3}));                 // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
 |         output = new NDArray(output->permute({0, 4, 1, 2, 3}));                 // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
 | ||||||
| 
 | 
 | ||||||
|  |     std::vector<int> colPermut; | ||||||
|  |     if(1 == wFormat) | ||||||
|  |         colPermut = {1,2,3,4,0,5,6,7}; | ||||||
|  |     else | ||||||
|  |         colPermut = {2,3,4,1,0,5,6,7}; | ||||||
|  | 
 | ||||||
|     if(isSameMode)         // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
 |     if(isSameMode)         // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
 | ||||||
|         ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); |         ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); | ||||||
| 
 | 
 | ||||||
|     NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(),  block.launchContext()); |     NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(),  block.launchContext()); | ||||||
| 
 | 
 | ||||||
|     //----- calculation of output -----//
 |     //----- calculation of output -----//
 | ||||||
|     // NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
 |     // [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
 | ||||||
|     // NCDHW: [kD, kH, kW, oC, iC] x [bS, iC, iD, iH, iW] = [kD, kH, kW, oC, bS, iD, iH, iW]
 |     // [iC, oC, kD, kH, kW] x [bS, iD, iH, iW, iC] = [oC, kD, kH, kW, bS, iD, iH, iW]
 | ||||||
|     sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7});   // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
 |     // [iC, kD, kH, kW, oC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
 | ||||||
|     ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW);                   // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
 |     sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut);       // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
 | ||||||
|  |     ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW);     // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
 | ||||||
| 
 | 
 | ||||||
|     //----- add biases if required -----//
 |     //----- add biases if required -----//
 | ||||||
|     if(bias) |     if(bias) | ||||||
| @ -101,7 +109,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { | |||||||
| DECLARE_SHAPE_FN(deconv3d) { | DECLARE_SHAPE_FN(deconv3d) { | ||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo   = inputShape->at(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto inputShapeInfo   = inputShape->at(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weightsShapeInfo = inputShape->at(1);                                    // [kD, kH, kW, oC, iC] always
 |     auto weightsShapeInfo = inputShape->at(1);                                    // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     auto biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;      // [oC]
 |     auto biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 5; |     const int rank = 5; | ||||||
| @ -122,8 +130,9 @@ DECLARE_SHAPE_FN(deconv3d) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int isSameMode = INT_ARG(12);                                               // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(12);                                               // 0-VALID, 1-SAME
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiD, indWoC(3); |     int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); | ||||||
|     if(!isNCDHW) { |     if(!isNCDHW) { | ||||||
|         indIOioC = 4; indIiD = 1; |         indIOioC = 4; indIiD = 1; | ||||||
|     } |     } | ||||||
| @ -138,7 +147,7 @@ DECLARE_SHAPE_FN(deconv3d) { | |||||||
|     const int iC = inputShapeInfo[indIOioC+1];                  // input channels
 |     const int iC = inputShapeInfo[indIOioC+1];                  // input channels
 | ||||||
|     const int oC = weightsShapeInfo[indWoC+1];                  // output channels
 |     const int oC = weightsShapeInfo[indWoC+1];                  // output channels
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>  expectedWeightsShape = {kD, kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if (biasShapeInfo) |     if (biasShapeInfo) | ||||||
|         REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); |         REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); | ||||||
| @ -174,12 +183,12 @@ DECLARE_SHAPE_FN(deconv3d) { | |||||||
| CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { | CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
 |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
 | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, oC, iC] always
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 5, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 5, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); | ||||||
| @ -201,16 +210,17 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int isSameMode = INT_ARG(12);                                               // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(12);                                               // 0-VALID, 1-SAME
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); | ||||||
| 
 | 
 | ||||||
|     int trueoD, trueoH, trueoW;          // true output depth, height, width
 |     int trueoD, trueoH, trueoW;          // true output depth, height, width
 | ||||||
|     ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
| @ -221,7 +231,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { | |||||||
| 
 | 
 | ||||||
|      // ----- calculation of gradI -> pass it through conv3d_ff ----- //
 |      // ----- calculation of gradI -> pass it through conv3d_ff ----- //
 | ||||||
|     sd::ops::conv3dnew conv3d; |     sd::ops::conv3dnew conv3d; | ||||||
|     const Nd4jStatus status = conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW,  isSameMode,  !isNCDHW}, {}); |     const Nd4jStatus status = conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW,  isSameMode,  !isNCDHW, wFormat}, {}); | ||||||
|     if (status != ND4J_STATUS_OK) |     if (status != ND4J_STATUS_OK) | ||||||
|         return status; |         return status; | ||||||
| 
 | 
 | ||||||
| @ -235,10 +245,16 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { | |||||||
|     else |     else | ||||||
|         inputAxesForDot = {0, 2, 3, 4};                                         // bS, iD, iH, iW
 |         inputAxesForDot = {0, 2, 3, 4};                                         // bS, iD, iH, iW
 | ||||||
| 
 | 
 | ||||||
|  |     std::vector<int> gradWAxes;     // empty for wFormat = 1
 | ||||||
|  |     if(0 == wFormat) | ||||||
|  |         gradWAxes = {4,3,0,1,2}; | ||||||
|  |     else if(2 == wFormat) | ||||||
|  |         gradWAxes = {0,4,1,2,3}; | ||||||
|  | 
 | ||||||
|     // ----- calculation of gradW ----- //
 |     // ----- calculation of gradW ----- //
 | ||||||
|     auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW},  input->dataType(), block.launchContext()); |     auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW},  input->dataType(), block.launchContext()); | ||||||
|     ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);                  // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
 |     ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);     // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
 | ||||||
|     MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2});   // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]
 |     MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, gradWAxes);   // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]
 | ||||||
| 
 | 
 | ||||||
|     // ----- calculation of gradB ----- //
 |     // ----- calculation of gradB ----- //
 | ||||||
|     if(gradB) { |     if(gradB) { | ||||||
| @ -267,7 +283,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { | |||||||
| DECLARE_SHAPE_FN(deconv3d_bp) { | DECLARE_SHAPE_FN(deconv3d_bp) { | ||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo   = inputShape->at(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto inputShapeInfo   = inputShape->at(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weightsShapeInfo = inputShape->at(1);                                                // [kD, kH, kW, oC, iC] always
 |     auto weightsShapeInfo = inputShape->at(1);                                                // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     Nd4jLong* biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;             // [oC]
 |     Nd4jLong* biasShapeInfo    = block.width() > 3 ? inputShape->at(2) : nullptr;             // [oC]
 | ||||||
|     Nd4jLong* gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 |     Nd4jLong* gradOShapeInfo   = block.width() > 3 ? inputShape->at(3) : inputShape->at(2);   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
| @ -290,8 +306,9 @@ DECLARE_SHAPE_FN(deconv3d_bp) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int isSameMode = INT_ARG(12);                                               // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(12);                                               // 0-VALID, 1-SAME
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiD, indWoC(3); |     int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); | ||||||
|     if(!isNCDHW) { |     if(!isNCDHW) { | ||||||
|         indIOioC = 4; indIiD = 1; |         indIOioC = 4; indIiD = 1; | ||||||
|     } |     } | ||||||
| @ -310,8 +327,8 @@ DECLARE_SHAPE_FN(deconv3d_bp) { | |||||||
|     ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIiD,indIiD+1,indIiD+2}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIiD,indIiD+1,indIiD+2}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0,  "CUSTOM DECONV3D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0,  "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if(biasShapeInfo) |     if(biasShapeInfo) | ||||||
|         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); |         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ namespace ops  { | |||||||
| CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { | CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] = iC*mC
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] = iC*mC
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_NULLIFIED(0);                                   // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
 |     auto output  = OUTPUT_NULLIFIED(0);                                   // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
 | ||||||
| @ -50,19 +50,20 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC);                           // channels multiplier
 |     mC = weights->sizeAt(indWmC);                           // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !",  ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !",  ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); |     REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); |     ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -75,7 +76,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { | |||||||
| DECLARE_SHAPE_FN(depthwise_conv2d) { | DECLARE_SHAPE_FN(depthwise_conv2d) { | ||||||
| 
 | 
 | ||||||
|     Nd4jLong* inputShapeInfo   = inputShape->at(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     Nd4jLong* inputShapeInfo   = inputShape->at(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     Nd4jLong* weightsShapeInfo = inputShape->at(1);                                    // [kH, kW, iC, mC] always
 |     Nd4jLong* weightsShapeInfo = inputShape->at(1);                                    // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     Nd4jLong* biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;      // [oC] = iC*mC
 |     Nd4jLong* biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;      // [oC] = iC*mC
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 4; |     const int rank = 4; | ||||||
| @ -92,8 +93,9 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indWmC(3); |     int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; |         indIOioC = 3; indIiH = 1; | ||||||
|     } |     } | ||||||
| @ -109,7 +111,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { | |||||||
|     const int oC = iC*mC;                                       // output channels
 |     const int oC = iC*mC;                                       // output channels
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if (biasShapeInfo) |     if (biasShapeInfo) | ||||||
|         REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); |         REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); | ||||||
| @ -148,12 +150,12 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { | |||||||
| CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { | CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, mC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
|     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kH, kW, iC, mC] always
 |     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); | ||||||
| @ -170,23 +172,24 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC);                           // channels multiplier
 |     mC = weights->sizeAt(indWmC);                           // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // correct output height, width
 |     int trueoH, trueoW;          // correct output height, width
 | ||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); |     ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -214,8 +217,9 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indWmC(3); |     int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; |         indIOioC = 3; indIiH = 1; | ||||||
|     } |     } | ||||||
| @ -234,7 +238,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) { | |||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>  expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong>  expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|     std::vector<Nd4jLong>  expectedWeightsShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong>  expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0,  "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0,  "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); | ||||||
|     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if(biasShapeInfo) |     if(biasShapeInfo) | ||||||
|  | |||||||
| @ -29,7 +29,7 @@ namespace ops  { | |||||||
| CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { | CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [1,  1,  iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW)
 |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW)
 | ||||||
| @ -47,18 +47,19 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { | |||||||
|     int pW = 0;                                                             // paddings width
 |     int pW = 0;                                                             // paddings width
 | ||||||
|     int dH = 1;                                                             // dilations height
 |     int dH = 1;                                                             // dilations height
 | ||||||
|     int dW = 1;                                                             // dilations width
 |     int dW = 1;                                                             // dilations width
 | ||||||
|     int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1;       // INT_ARG(0): 0-NCHW, 1-NHWC
 |     int isNCHW  = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1;      // INT_ARG(0): 0-NCHW, 1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0;       // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW); |     ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -73,7 +74,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { | |||||||
| DECLARE_SHAPE_FN(pointwise_conv2d) { | DECLARE_SHAPE_FN(pointwise_conv2d) { | ||||||
| 
 | 
 | ||||||
|     Nd4jLong* inputShapeInfo  = inputShape->at(0);                                   // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     Nd4jLong* inputShapeInfo  = inputShape->at(0);                                   // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     Nd4jLong* weightsShapeInfo  = inputShape->at(1);                                 // [1,  1,  iC, oC] always
 |     Nd4jLong* weightsShapeInfo  = inputShape->at(1);                                 // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC]
 | ||||||
|     Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr;       // [oC]
 |     Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr;       // [oC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = 4; |     const int rank = 4; | ||||||
| @ -81,8 +82,9 @@ DECLARE_SHAPE_FN(pointwise_conv2d) { | |||||||
|     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); |     REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); | ||||||
| 
 | 
 | ||||||
|     int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1;       // INT_ARG(0): 0-NCHW, 1-NHWC
 |     int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1;       // INT_ARG(0): 0-NCHW, 1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0;       // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indWoC(3); |     int indIOioC, indWoC(0 == wFormat ? 3 : 0); | ||||||
|     if(!isNCHW) |     if(!isNCHW) | ||||||
|         indIOioC = 3; |         indIOioC = 3; | ||||||
|     else |     else | ||||||
| @ -92,7 +94,7 @@ DECLARE_SHAPE_FN(pointwise_conv2d) { | |||||||
|     const int iC = inputShapeInfo[indIOioC+1];                   // input channels
 |     const int iC = inputShapeInfo[indIOioC+1];                   // input channels
 | ||||||
|     const int oC = weightsShapeInfo[indWoC+1];                   // output channels
 |     const int oC = weightsShapeInfo[indWoC+1];                   // output channels
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); | ||||||
|     if (biasShapeInfo) |     if (biasShapeInfo) | ||||||
|         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); |         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); | ||||||
|  | |||||||
| @ -33,8 +33,8 @@ namespace ops  { | |||||||
| CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { | CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { | ||||||
| 
 | 
 | ||||||
|     NDArray *input        = INPUT_VARIABLE(0);                    // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 |     NDArray *input        = INPUT_VARIABLE(0);                    // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 | ||||||
|     NDArray *weightsDepth = INPUT_VARIABLE(1);                    // [kH, kW, iC, mC]  always
 |     NDArray *weightsDepth = INPUT_VARIABLE(1);                    // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     NDArray *weightsPoint = nullptr;                              // [1, 1, iC*mC, oC] always
 |     NDArray *weightsPoint = nullptr;                              // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
 | ||||||
|     NDArray *bias         = nullptr;                              // [oC], if weightsPoint=nullptr then oC = iC*mC
 |     NDArray *bias         = nullptr;                              // [oC], if weightsPoint=nullptr then oC = iC*mC
 | ||||||
| 
 | 
 | ||||||
|     NDArray *output    = OUTPUT_NULLIFIED(0);                      // [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW]  (NCHW)
 |     NDArray *output    = OUTPUT_NULLIFIED(0);                      // [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW]  (NCHW)
 | ||||||
| @ -66,17 +66,19 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { | |||||||
|     int dH = INT_ARG(6);                                                        // dilations height
 |     int dH = INT_ARG(6);                                                        // dilations height
 | ||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     int isNCHW  = block.getIArguments()->size() > 9  ? !INT_ARG(9) : 1;         // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
 |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
 | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weightsDepth->sizeAt(indWmC);                      // channels multiplier
 |     mC = weightsDepth->sizeAt(indWmC);                      // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); |     REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); | ||||||
|     if(weightsPoint) { |     if(weightsPoint) { | ||||||
|         std::vector<Nd4jLong>  expectedWeightsPShape = {1, 1, iC*mC, oC}; |         std::vector<Nd4jLong>  expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); | ||||||
|         REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); |         REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); | ||||||
|     } |     } | ||||||
|     if (bias) |     if (bias) | ||||||
| @ -84,11 +86,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { | |||||||
| 
 | 
 | ||||||
|     if (iC == 1) { |     if (iC == 1) { | ||||||
|         nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n",""); |         nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n",""); | ||||||
|         ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); |         ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); | ||||||
|         return Status::OK(); |         return Status::OK(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); |     ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -103,8 +105,8 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { | |||||||
| DECLARE_SHAPE_FN(sconv2d) { | DECLARE_SHAPE_FN(sconv2d) { | ||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo    = inputShape->at(0);         // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 |     auto inputShapeInfo    = inputShape->at(0);         // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 | ||||||
|     auto weightsDShapeInfo = inputShape->at(1);         // [kH, kW, iC, mC]  always
 |     auto weightsDShapeInfo = inputShape->at(1);         // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     Nd4jLong* weightsPShapeInfo = nullptr;              // [1, 1, iC*mC, oC] always
 |     Nd4jLong* weightsPShapeInfo = nullptr;              // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
 | ||||||
|     Nd4jLong* biasShapeInfo     = nullptr;              // [oC], oC = iC*mC if weightsPoint=nullptr
 |     Nd4jLong* biasShapeInfo     = nullptr;              // [oC], oC = iC*mC if weightsPoint=nullptr
 | ||||||
| 
 | 
 | ||||||
|     if(block.width() == 3) |     if(block.width() == 3) | ||||||
| @ -135,8 +137,9 @@ DECLARE_SHAPE_FN(sconv2d) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indWmC(3); |     int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; |         indIOioC = 3; indIiH = 1; | ||||||
|     } |     } | ||||||
| @ -148,13 +151,13 @@ DECLARE_SHAPE_FN(sconv2d) { | |||||||
|     const int iH = inputShapeInfo[indIiH+1];                                        // input height
 |     const int iH = inputShapeInfo[indIiH+1];                                        // input height
 | ||||||
|     const int iW = inputShapeInfo[indIiH+2];                                        // input width
 |     const int iW = inputShapeInfo[indIiH+2];                                        // input width
 | ||||||
|     const int iC = inputShapeInfo[indIOioC+1];                                      // input channels
 |     const int iC = inputShapeInfo[indIOioC+1];                                      // input channels
 | ||||||
|     const int mC = weightsDShapeInfo[indWmC+1];                                      // channel multiplier
 |     const int mC = weightsDShapeInfo[indWmC+1];                                     // channel multiplier
 | ||||||
|     const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC;       // output channels (oC or iC*mC)
 |     const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC;         // output channels (oC or iC*mC)
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>  expectedWeightsDShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); | ||||||
|     if(weightsPShapeInfo) { |     if(weightsPShapeInfo) { | ||||||
|         std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC}; |         std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); | ||||||
|         REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); |         REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); | ||||||
|     } |     } | ||||||
|     if (biasShapeInfo) |     if (biasShapeInfo) | ||||||
| @ -195,13 +198,13 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { | |||||||
| 
 | 
 | ||||||
|     NDArray *input        = INPUT_VARIABLE(0);                                           // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 |     NDArray *input        = INPUT_VARIABLE(0);                                           // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 | ||||||
|     NDArray *gradO        = INPUT_VARIABLE(1);                                           // [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     NDArray *gradO        = INPUT_VARIABLE(1);                                           // [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
|     NDArray *weightsDepth = INPUT_VARIABLE(2);                                           // [kH, kW, iC, mC] always
 |     NDArray *weightsDepth = INPUT_VARIABLE(2);                                           // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     NDArray *weightsPoint = nullptr;                                                     // [1, 1, iC*mC, oC] always
 |     NDArray *weightsPoint = nullptr;                                                     // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
 | ||||||
|     NDArray *bias         = nullptr;                                                     // [oC], oC = iC*mC if weightsPoint=nullptr
 |     NDArray *bias         = nullptr;                                                     // [oC], oC = iC*mC if weightsPoint=nullptr
 | ||||||
| 
 | 
 | ||||||
|     NDArray *gradI  = OUTPUT_NULLIFIED(0);                                                // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     NDArray *gradI  = OUTPUT_NULLIFIED(0);                                                // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
|     NDArray *gradWD = OUTPUT_NULLIFIED(1);                                                // [kH, kW, iC, mC] always
 |     NDArray *gradWD = OUTPUT_NULLIFIED(1);                                                // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     NDArray *gradWP = nullptr;                                                           // [1, 1, iC*mC, oC] always
 |     NDArray *gradWP = nullptr;                                                           // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
 | ||||||
|     NDArray *gradB  = nullptr;                                                           // [oC]
 |     NDArray *gradB  = nullptr;                                                           // [oC]
 | ||||||
| 
 | 
 | ||||||
|     if(block.width() == 4) { |     if(block.width() == 4) { | ||||||
| @ -244,17 +247,18 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
 |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
 | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weightsDepth->sizeAt(indWmC);                      // channels multiplier
 |     mC = weightsDepth->sizeAt(indWmC);                      // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); |     REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); | ||||||
|     REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape),       0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str()); |     REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape),       0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str()); | ||||||
|     if(weightsPoint) { |     if(weightsPoint) { | ||||||
|         std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC}; |         std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); | ||||||
|         REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); |         REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); | ||||||
|         REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape),       0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(gradWP).c_str()); |         REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape),       0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(gradWP).c_str()); | ||||||
|     } |     } | ||||||
| @ -274,12 +278,12 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { | |||||||
| 
 | 
 | ||||||
|         auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC}); |         auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC}); | ||||||
|         auto resultFF  = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext()); |         auto resultFF  = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext()); | ||||||
|         ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); |         ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|         auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); |         auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|         auto gradIDepth  = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext());                 // [bS, oH, oW, iC*mC]  (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
 |         auto gradIDepth  = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext());                 // [bS, oH, oW, iC*mC]  (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
 | ||||||
| 
 | 
 | ||||||
|         ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW);    // in this case oH=iH and oW=iW
 |         ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW, wFormat);    // in this case oH=iH and oW=iW
 | ||||||
| 
 | 
 | ||||||
|         gradO = gradIDepth; |         gradO = gradIDepth; | ||||||
|         bias = gradB = nullptr;                     // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step
 |         bias = gradB = nullptr;                     // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step
 | ||||||
| @ -288,7 +292,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // ----- apply depthwise_conv2d_bp ----- //
 |     // ----- apply depthwise_conv2d_bp ----- //
 | ||||||
|     ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); |     ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     if(weightsPoint) |     if(weightsPoint) | ||||||
|         delete gradO; |         delete gradO; | ||||||
| @ -301,8 +305,8 @@ DECLARE_SHAPE_FN(sconv2d_bp) { | |||||||
| 
 | 
 | ||||||
|     auto inputShapeInfo    = inputShape->at(0);                 // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 |     auto inputShapeInfo    = inputShape->at(0);                 // [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 | ||||||
|     auto gradOShapeInfo    = inputShape->at(1);                 // [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradOShapeInfo    = inputShape->at(1);                 // [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
|     auto weightsDShapeInfo = inputShape->at(2);                 // [kH, kW, iC, mC]  always
 |     auto weightsDShapeInfo = inputShape->at(2);                 // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     Nd4jLong* weightsPShapeInfo = nullptr;                      // [1, 1, iC*mC, oC] always
 |     Nd4jLong* weightsPShapeInfo = nullptr;                      // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
 | ||||||
|     Nd4jLong* biasShapeInfo     = nullptr;                      // [oC], oC = iC*mC if weightsPoint=nullptr
 |     Nd4jLong* biasShapeInfo     = nullptr;                      // [oC], oC = iC*mC if weightsPoint=nullptr
 | ||||||
| 
 | 
 | ||||||
|     if(block.width() == 4) { |     if(block.width() == 4) { | ||||||
| @ -335,8 +339,9 @@ DECLARE_SHAPE_FN(sconv2d_bp) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int indIOioC, indIiH, indWmC(3); |     int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         indIOioC = 3; indIiH = 1; |         indIOioC = 3; indIiH = 1; | ||||||
|     } |     } | ||||||
| @ -356,10 +361,10 @@ DECLARE_SHAPE_FN(sconv2d_bp) { | |||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong> expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); | ||||||
|     std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); |     REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); | ||||||
|     if(weightsPShapeInfo) { |     if(weightsPShapeInfo) { | ||||||
|         std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC}; |         std::vector<Nd4jLong> expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); | ||||||
|         REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); |         REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); | ||||||
|     } |     } | ||||||
|     if (biasShapeInfo) |     if (biasShapeInfo) | ||||||
|  | |||||||
| @ -166,7 +166,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|  | |||||||
| @ -55,7 +55,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); |     REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); | ||||||
| @ -172,7 +172,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|  | |||||||
| @ -168,7 +168,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|  | |||||||
| @ -55,7 +55,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); |     REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); | ||||||
| @ -174,7 +174,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|  | |||||||
| @ -167,7 +167,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|  | |||||||
| @ -154,15 +154,24 @@ namespace sd { | |||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // evaluates sizes values and indexes using input and output arrays depending on data format
 |             // evaluates sizes values and indexes using input and output arrays depending on data format
 | ||||||
|             static inline void getSizesAndIndexesConv2d(const bool isNCHW, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { |             static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { | ||||||
|                 getSizesAndIndexesConv2d(isNCHW, input.getShapeInfo(), output.getShapeInfo(), bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |                 getSizesAndIndexesConv2d(isNCHW, wFormat, input.getShapeInfo(), output.getShapeInfo(), bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             static inline void getSizesAndIndexesConv2d(const bool isNCHW, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { |             static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { | ||||||
|                 // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |                 // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|                 // weights [kH, kW, iC, oC] always
 |                 // weights [kH, kW, iC, oC] (wFormat = 0), [oC, iC, kH, kW] (wFormat = 1), [oC, kH, kW, iC] (wFormat = 2)
 | ||||||
|                 // output  [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 |                 // output  [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 | ||||||
|                 indWkH = 0; indWiC = 2; indWoC = 3; | 
 | ||||||
|  |                 if(0 == wFormat) { | ||||||
|  |                     indWkH = 0; indWiC = 2; indWoC = 3; | ||||||
|  |                 } | ||||||
|  |                 else if(1 == wFormat) { | ||||||
|  |                     indWkH = 2; indWiC = 1; indWoC = 0; | ||||||
|  |                 } | ||||||
|  |                 else { | ||||||
|  |                     indWkH = 1; indWiC = 3; indWoC = 0; | ||||||
|  |                 } | ||||||
| 
 | 
 | ||||||
|                 if(!isNCHW) { |                 if(!isNCHW) { | ||||||
|                     indIOioC = 3; indIiH = 1; indOoH = 1; |                     indIOioC = 3; indIiH = 1; indOoH = 1; | ||||||
| @ -181,12 +190,21 @@ namespace sd { | |||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // evaluates sizes values and indexes using input and output arrays depending on data format
 |             // evaluates sizes values and indexes using input and output arrays depending on data format
 | ||||||
|             static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) { |             static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) { | ||||||
|                 // input   [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |                 // input   [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|                 // weights [kD, kH, kW, iC, oC] (NDHWC) or [oC, iC, kD, kH, kW] (NCDHW)
 |                 // weights [kD, kH, kW, iC, oC] (wFormat = 0), [oC, iC, kD, kH, kW] (wFormat = 1), [oC, kD, kH, kW, iC] (wFormat = 2)
 | ||||||
|                 // output  [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 |                 // output  [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 | ||||||
| 
 | 
 | ||||||
|                 indWkD = 0; indWiC = 3; indWoC = 4; |                 if(0 == wFormat) { | ||||||
|  |                     indWkD = 0; indWiC = 3; indWoC = 4; | ||||||
|  |                 } | ||||||
|  |                 else if(1 == wFormat) { | ||||||
|  |                     indWkD = 2; indWiC = 1; indWoC = 0; | ||||||
|  |                 } | ||||||
|  |                 else { | ||||||
|  |                     indWkD = 1; indWiC = 4; indWoC = 0; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|                 if(!isNCDHW) { |                 if(!isNCDHW) { | ||||||
|                     indIOioC = 4; indIOioD = 1; |                     indIOioC = 4; indIOioD = 1; | ||||||
|                 } |                 } | ||||||
| @ -203,7 +221,6 @@ namespace sd { | |||||||
|                 oD = output.sizeAt(indIOioD);                  // output depth
 |                 oD = output.sizeAt(indIOioD);                  // output depth
 | ||||||
|                 oH = output.sizeAt(indIOioD+1);                // output height
 |                 oH = output.sizeAt(indIOioD+1);                // output height
 | ||||||
|                 oW = output.sizeAt(indIOioD+2);                // output width
 |                 oW = output.sizeAt(indIOioD+2);                // output width
 | ||||||
| 
 |  | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) {
 |             // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) {
 | ||||||
| @ -254,19 +271,41 @@ namespace sd { | |||||||
|             //     }
 |             //     }
 | ||||||
|             // }
 |             // }
 | ||||||
| 
 | 
 | ||||||
|             static void conv2d(sd::graph::Context  &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); |             static std::vector<Nd4jLong> expectWeightsShape(const int wFormat, const int kH, const int kW, const int iC, const int oC) { | ||||||
|  | 
 | ||||||
|  |                 if(0 == wFormat) | ||||||
|  |                     return std::vector<Nd4jLong>({kH, kW, iC, oC}); | ||||||
|  | 
 | ||||||
|  |                 if(1 == wFormat) | ||||||
|  |                     return std::vector<Nd4jLong>({oC, iC, kH, kW}); | ||||||
|  | 
 | ||||||
|  |                 return std::vector<Nd4jLong>({oC, kH, kW, iC}); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             static std::vector<Nd4jLong> expectWeightsShape(const int wFormat, const int kD, const int kH, const int kW, const int iC, const int oC) { | ||||||
|  | 
 | ||||||
|  |                 if(0 == wFormat) | ||||||
|  |                     return std::vector<Nd4jLong>({kD, kH, kW, iC, oC}); | ||||||
|  | 
 | ||||||
|  |                 if(1 == wFormat) | ||||||
|  |                     return std::vector<Nd4jLong>({oC, iC, kD, kH, kW}); | ||||||
|  | 
 | ||||||
|  |                 return std::vector<Nd4jLong>({oC, kD, kH, kW, iC}); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             static void conv2d(sd::graph::Context  &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); | ||||||
| 
 | 
 | ||||||
|             // static void conv2d(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
 |             // static void conv2d(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
 | ||||||
| 
 | 
 | ||||||
|             // static void conv2dBP(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
 |             // static void conv2dBP(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
 | ||||||
| 
 | 
 | ||||||
|             static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); |             static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); | ||||||
| 
 | 
 | ||||||
|             static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); |             static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); | ||||||
| 
 | 
 | ||||||
|             static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); |             static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); | ||||||
| 
 | 
 | ||||||
|             static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); |             static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); | ||||||
| 
 | 
 | ||||||
|             static void vol2col(sd::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); |             static void vol2col(sd::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -258,10 +258,10 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
|         template <typename X, typename Y> |         template <typename X, typename Y> | ||||||
|         static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|             // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |             // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|             // weights [kH, kW, iC, oC] always
 |             // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|             // bias    [oC]
 |             // bias    [oC]
 | ||||||
|             // output  [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 |             // output  [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 | ||||||
| 
 | 
 | ||||||
| @ -278,7 +278,7 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|             int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |             int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|             int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |             int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|             ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |             ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
| @ -291,6 +291,14 @@ namespace sd { | |||||||
|             else |             else | ||||||
|                 input = new NDArray(input->permute({0, 3, 1, 2}));                         // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC
 |                 input = new NDArray(input->permute({0, 3, 1, 2}));                         // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC
 | ||||||
| 
 | 
 | ||||||
|  |             std::vector<int> wAxes; | ||||||
|  |             if(0 == wFormat) | ||||||
|  |                 wAxes = {0, 1, 2}; | ||||||
|  |             else if(1 == wFormat) | ||||||
|  |                 wAxes = {2, 3, 1}; | ||||||
|  |             else | ||||||
|  |                 wAxes = {1, 2, 3}; | ||||||
|  | 
 | ||||||
|             NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); |             NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); | ||||||
|             NDArray colP = col.permute({0, 5, 3, 4, 1, 2});            // {bS, iC, kH, kW, oH, oW}
 |             NDArray colP = col.permute({0, 5, 3, 4, 1, 2});            // {bS, iC, kH, kW, oH, oW}
 | ||||||
|             NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); |             NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); | ||||||
| @ -298,7 +306,7 @@ namespace sd { | |||||||
|             //----- calculation of output -----//
 |             //----- calculation of output -----//
 | ||||||
|             auto ctx = block.launchContext(); |             auto ctx = block.launchContext(); | ||||||
|             helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
 |             helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
 | ||||||
|             MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]
 |             MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]
 | ||||||
| 
 | 
 | ||||||
|             //----- assign outTemp to output  -----//
 |             //----- assign outTemp to output  -----//
 | ||||||
|             if(isNCHW) { |             if(isNCHW) { | ||||||
| @ -319,15 +327,15 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
|         template <typename X, typename Y> |         template <typename X, typename Y> | ||||||
|         static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|             // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |             // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|             // weights [kH, kW, iC, oC] always
 |             // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|             // bias    [oC]
 |             // bias    [oC]
 | ||||||
|             // gradO   [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |             // gradO   [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|             // gradI    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |             // gradI    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
|             // gradW    [kH, kW, iC, oC] always
 |             // gradW    [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|             // gradB    [oC]
 |             // gradB    [oC]
 | ||||||
| 
 | 
 | ||||||
|             // kH         filter(kernel) height
 |             // kH         filter(kernel) height
 | ||||||
| @ -343,7 +351,7 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|             int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |             int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|             int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |             int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|             ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |             ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
| @ -359,13 +367,28 @@ namespace sd { | |||||||
|                 gradOaxesForDot  = {0, 2, 3};                                           // bS, oH, oW
 |                 gradOaxesForDot  = {0, 2, 3};                                           // bS, oH, oW
 | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|  |             std::vector<int> wPermut, colPermut; | ||||||
|  | 
 | ||||||
|  |             if(0 == wFormat) { | ||||||
|  |                 wPermut   = {2, 0, 1, 3}; | ||||||
|  |                 colPermut = {2, 3, 1, 0, 4, 5}; | ||||||
|  |             } | ||||||
|  |             else if(1 == wFormat) { | ||||||
|  |                 wPermut   = {1, 2, 3, 0}; | ||||||
|  |                 colPermut = {1, 2, 3, 0, 4, 5}; | ||||||
|  |             } | ||||||
|  |             else { | ||||||
|  |                 wPermut   = {3, 1, 2, 0}; | ||||||
|  |                 colPermut = {2, 3, 1, 0, 4, 5}; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|             NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); |             NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); | ||||||
| 
 | 
 | ||||||
|             // ----- calculation of gradW ----- //
 |             // ----- calculation of gradW ----- //
 | ||||||
|             if(gradW) { |             if(gradW) { | ||||||
|                 auto ctx = block.launchContext(); |                 auto ctx = block.launchContext(); | ||||||
|                 helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));   // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
 |                 helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));   // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
 | ||||||
|                 sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3});       // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
 |                 sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut);       // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
 | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // ----- calculation of gradB ----- //
 |             // ----- calculation of gradB ----- //
 | ||||||
| @ -379,9 +402,12 @@ namespace sd { | |||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             //----- calculation of gradI -----//
 |             //----- calculation of gradI -----//
 | ||||||
|             sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5});  // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
 |             // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
 | ||||||
|  |             // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW]
 | ||||||
|  |             // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
 | ||||||
|  |             sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); | ||||||
| 
 | 
 | ||||||
|             helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);                          // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
 |             helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);       // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
 | ||||||
| 
 | 
 | ||||||
|             if(!isNCHW) { |             if(!isNCHW) { | ||||||
|                 delete input; |                 delete input; | ||||||
| @ -391,10 +417,10 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
|         template <typename X, typename Y> |         template <typename X, typename Y> | ||||||
|         static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|             // input     [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |             // input     [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|             // weights   [kH, kW, iC, mC] always
 |             // weights   [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|             // bias      [oC] = iC*mC
 |             // bias      [oC] = iC*mC
 | ||||||
|             // output    [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
 |             // output    [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
 | ||||||
| 
 | 
 | ||||||
| @ -411,23 +437,30 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|             int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 |             int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 | ||||||
|             int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 |             int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 | ||||||
|             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|             mC = weights->sizeAt(indWmC);                           // channels multiplier
 |             mC = weights->sizeAt(indWmC);                           // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|             std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}};  // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW]
 |             std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}};  // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW]
 | ||||||
|             std::vector<std::vector<Nd4jLong>> modifOutput; |             std::vector<std::vector<Nd4jLong>> modifOutput, modifWeights; | ||||||
|             std::vector<Nd4jLong> outReShape; |             std::vector<Nd4jLong> outReShape; | ||||||
| 
 | 
 | ||||||
|             if(!isNCHW) { |             if(!isNCHW) { | ||||||
|                 outReShape = {bS, oH, oW, iC, mC};                                              // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
 |                 outReShape = {bS, oH, oW, iC, mC};                                              // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
 | ||||||
|                 modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}};                                 // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
 |                 modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}};                                 // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
 | ||||||
|                 input = new NDArray(input->permute({0, 3, 1, 2}));                             // [bS,iH,iW,iC]    -> [bS,iC,iH,iW]
 |                 input = new NDArray(input->permute({0, 3, 1, 2}));                              // [bS,iH,iW,iC]    -> [bS,iC,iH,iW]
 | ||||||
|             } |             } | ||||||
|             else { |             else { | ||||||
|                 outReShape = {bS, iC, mC, oH, oW};                                              // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
 |                 outReShape = {bS, iC, mC, oH, oW};                                              // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
 | ||||||
|                 modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}};                                 // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
 |                 modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}};                                 // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
 | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|  |             if(0 == wFormat) | ||||||
|  |                 modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; | ||||||
|  |             else if(1 == wFormat) | ||||||
|  |                 modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; | ||||||
|  |             else | ||||||
|  |                 modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; | ||||||
|  | 
 | ||||||
|             if(paddingMode == 1)                       // SAME
 |             if(paddingMode == 1)                       // SAME
 | ||||||
|                 ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); |                 ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); | ||||||
| 
 | 
 | ||||||
| @ -435,7 +468,7 @@ namespace sd { | |||||||
|             NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); |             NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); | ||||||
| 
 | 
 | ||||||
|             helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
 |             helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
 | ||||||
|             MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput);              // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
 |             MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput);              // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
 | ||||||
| 
 | 
 | ||||||
|             if(bias) |             if(bias) | ||||||
|                 // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
 |                 // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
 | ||||||
| @ -447,14 +480,14 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
|         template <typename X, typename Y> |         template <typename X, typename Y> | ||||||
|         static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|             // input    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |             // input    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|             // weights  [kH, kW, iC, mC] always
 |             // weights  [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|             // bias     [oC] = [iC*mC]
 |             // bias     [oC] = [iC*mC]
 | ||||||
|             // gradO    [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |             // gradO    [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
|             // gradI    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |             // gradI    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
|             // gradW    [kH, kW, iC, mC] always
 |             // gradW    [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|             // gradB    [oC]
 |             // gradB    [oC]
 | ||||||
| 
 | 
 | ||||||
|             //  kH          filter(kernel) height
 |             //  kH          filter(kernel) height
 | ||||||
| @ -470,19 +503,19 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|             int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 |             int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 | ||||||
|             int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 |             int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 | ||||||
|             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|             mC = weights->sizeAt(indWmC);                           // channels multiplier
 |             mC = weights->sizeAt(indWmC);                           // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|             std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}};      // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW]
 |             std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}};      // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW]
 | ||||||
|             std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2; |             std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2, modifWeights; | ||||||
|             std::vector<Nd4jLong> gradOreShape; |             std::vector<Nd4jLong> gradOreShape; | ||||||
| 
 | 
 | ||||||
|             if(!isNCHW) { |             if(!isNCHW) { | ||||||
|                 gradOreShape = {bS, oH, oW, iC, mC};                                            // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
 |                 gradOreShape = {bS, oH, oW, iC, mC};                                            // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC]
 | ||||||
|                 modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}};                                 // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
 |                 modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}};                                 // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
 | ||||||
|                 modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}};                                   // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
 |                 modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}};                                   // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
 | ||||||
|                 input = new NDArray(input->permute({0, 3, 1, 2}));                             // [bS,iH,iW,iC]    -> [bS,iC,iH,iW]
 |                 input = new NDArray(input->permute({0, 3, 1, 2}));                              // [bS,iH,iW,iC]    -> [bS,iC,iH,iW]
 | ||||||
|                 gradI = new NDArray(gradI->permute({0, 3, 1, 2}));                             // [bS,iH,iW,iC]    -> [bS,iC,iH,iW]
 |                 gradI = new NDArray(gradI->permute({0, 3, 1, 2}));                              // [bS,iH,iW,iC]    -> [bS,iC,iH,iW]
 | ||||||
|             } |             } | ||||||
|             else { |             else { | ||||||
|                 gradOreShape = {bS, iC, mC, oH, oW};                                            // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
 |                 gradOreShape = {bS, iC, mC, oH, oW};                                            // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW]
 | ||||||
| @ -490,6 +523,13 @@ namespace sd { | |||||||
|                 modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}};                                   // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
 |                 modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}};                                   // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
 | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|  |             if(0 == wFormat) | ||||||
|  |                 modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; | ||||||
|  |             else if(1 == wFormat) | ||||||
|  |                 modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; | ||||||
|  |             else | ||||||
|  |                 modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; | ||||||
|  | 
 | ||||||
|             if(paddingMode == 1)                       // SAME
 |             if(paddingMode == 1)                       // SAME
 | ||||||
|                 ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); |                 ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); | ||||||
| 
 | 
 | ||||||
| @ -499,7 +539,7 @@ namespace sd { | |||||||
|             // ----- calculation of gradW and gradB ----- //
 |             // ----- calculation of gradW and gradB ----- //
 | ||||||
| 
 | 
 | ||||||
|             helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
 |             helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
 | ||||||
|             sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}});  // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]
 |             sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights);  // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]
 | ||||||
| 
 | 
 | ||||||
|             // ----- calculation of gradB ----- //
 |             // ----- calculation of gradB ----- //
 | ||||||
|             if(gradB) { |             if(gradB) { | ||||||
| @ -513,8 +553,8 @@ namespace sd { | |||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             //----- calculation of gradI -----//
 |             //----- calculation of gradI -----//
 | ||||||
|             sd::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
 |             sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
 | ||||||
|             helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);                                       // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
 |             helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);       // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
 | ||||||
| 
 | 
 | ||||||
|             if(!isNCHW) { |             if(!isNCHW) { | ||||||
|                 delete input; |                 delete input; | ||||||
| @ -524,11 +564,11 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
|         template <typename X, typename Y> |         template <typename X, typename Y> | ||||||
|         static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|             // input         [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 |             // input         [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW)
 | ||||||
|             // weightsDepth  [kH, kW, iC, mC]  always
 |             // weightsDepth  [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|             // weightsPoint  [1, 1, iC*mC, oC] always
 |             // weightsPoint  [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC]
 | ||||||
|             // bias          [oC], oC = iC*mC if weightsPoint=nullptr
 |             // bias          [oC], oC = iC*mC if weightsPoint=nullptr
 | ||||||
|             // output is     [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW]  (NCHW)
 |             // output is     [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW]  (NCHW)
 | ||||||
| 
 | 
 | ||||||
| @ -545,7 +585,7 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|             int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
 |             int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
 | ||||||
|             int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 |             int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 | ||||||
|             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |             ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|             mC = weightsDepth->sizeAt(indWmC);                      // channels multiplier
 |             mC = weightsDepth->sizeAt(indWmC);                      // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|             NDArray* outputDepth = output; |             NDArray* outputDepth = output; | ||||||
| @ -553,11 +593,11 @@ namespace sd { | |||||||
|                 outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); |                 outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); | ||||||
| 
 | 
 | ||||||
|             // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
 |             // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
 | ||||||
|             ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW); |             ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|             // ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
 |             // ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
 | ||||||
|             if (weightsPoint) { |             if (weightsPoint) { | ||||||
|                 ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW);             // in this case oH=iH, oW=iW
 |                 ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat);             // in this case oH=iH, oW=iW
 | ||||||
|                 delete outputDepth; |                 delete outputDepth; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @ -1772,20 +1812,20 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|         void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
|         } |         } | ||||||
|         void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
|         } |         } | ||||||
|         void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
|         } |         } | ||||||
|         void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
|         } |         } | ||||||
|         void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { |         void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |             BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
|         } |         } | ||||||
|         void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { |         void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { | ||||||
|             BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); |             BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); | ||||||
|  | |||||||
| @ -217,10 +217,10 @@ void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, ND | |||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| template <typename X, typename Y> | template <typename X, typename Y> | ||||||
| static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) |     // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) | ||||||
|     // weights [kH, kW, iC, oC] always |     // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] | ||||||
|     // bias    [oC] |     // bias    [oC] | ||||||
|     // output  [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) |     // output  [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) | ||||||
| 
 | 
 | ||||||
| @ -237,7 +237,7 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
| @ -248,6 +248,14 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr | |||||||
|     else |     else | ||||||
|         input = new NDArray(input->permute({0, 3, 1, 2}));                         // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC |         input = new NDArray(input->permute({0, 3, 1, 2}));                         // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC | ||||||
| 
 | 
 | ||||||
|  |     std::vector<int> wAxes; | ||||||
|  |     if(0 == wFormat) | ||||||
|  |         wAxes = {0, 1, 2}; | ||||||
|  |     else if(1 == wFormat) | ||||||
|  |         wAxes = {2, 3, 1}; | ||||||
|  |     else | ||||||
|  |         wAxes = {1, 2, 3}; | ||||||
|  | 
 | ||||||
|     NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); |     NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); | ||||||
|     NDArray colP = col.permute({0, 5, 3, 4, 1, 2});            // {bS, iC, kH, kW, oH, oW} |     NDArray colP = col.permute({0, 5, 3, 4, 1, 2});            // {bS, iC, kH, kW, oH, oW} | ||||||
|     NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); |     NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); | ||||||
| @ -255,7 +263,7 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr | |||||||
|     //----- calculation of output -----// |     //----- calculation of output -----// | ||||||
|     auto ctx = block.launchContext(); |     auto ctx = block.launchContext(); | ||||||
|     helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] |     helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] | ||||||
|     MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] |     MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] | ||||||
| 
 | 
 | ||||||
|     //----- assign outTemp to output  -----// |     //----- assign outTemp to output  -----// | ||||||
|     if(isNCHW) { |     if(isNCHW) { | ||||||
| @ -275,16 +283,16 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| template <typename X, typename Y> | template <typename X, typename Y> | ||||||
| static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // input     [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) |     // input     [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) | ||||||
|     // weights   [kH, kW, iC, mC] always |     // weights   [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] | ||||||
|     // bias      [oC] = iC*mC |     // bias      [oC] = iC*mC | ||||||
|     // output    [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) |     // output    [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) | ||||||
| 
 | 
 | ||||||
| @ -301,23 +309,30 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC);                           // channels multiplier |     mC = weights->sizeAt(indWmC);                           // channels multiplier | ||||||
| 
 | 
 | ||||||
|     std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}};  // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] |     std::vector<std::vector<Nd4jLong>> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}};  // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] | ||||||
|     std::vector<std::vector<Nd4jLong>> modifOutput; |     std::vector<std::vector<Nd4jLong>> modifOutput, modifWeights; | ||||||
|     std::vector<Nd4jLong> outReShape; |     std::vector<Nd4jLong> outReShape; | ||||||
| 
 | 
 | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
|         outReShape = {bS, oH, oW, iC, mC};                                              // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] |         outReShape = {bS, oH, oW, iC, mC};                                              // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] | ||||||
|         modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}};                                 // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] |         modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}};                                 // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] | ||||||
|         input = new NDArray(input->permute({0, 3, 1, 2}));                             // [bS,iH,iW,iC]    -> [bS,iC,iH,iW] |         input = new NDArray(input->permute({0, 3, 1, 2}));                              // [bS,iH,iW,iC]    -> [bS,iC,iH,iW] | ||||||
|     } |     } | ||||||
|     else { |     else { | ||||||
|         outReShape = {bS, iC, mC, oH, oW};                                              // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] |         outReShape = {bS, iC, mC, oH, oW};                                              // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] | ||||||
|         modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}};                                 // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] |         modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}};                                 // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     if(0 == wFormat) | ||||||
|  |         modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; | ||||||
|  |     else if(1 == wFormat) | ||||||
|  |         modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; | ||||||
|  |     else | ||||||
|  |         modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; | ||||||
|  | 
 | ||||||
|     if(paddingMode == 1)                       // SAME |     if(paddingMode == 1)                       // SAME | ||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); |         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); | ||||||
| 
 | 
 | ||||||
| @ -325,7 +340,7 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co | |||||||
|     NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); |     NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); | ||||||
| 
 | 
 | ||||||
|     helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] |     helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] | ||||||
|     MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput);              // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] |     MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput);              // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] | ||||||
| 
 | 
 | ||||||
|     if(bias) |     if(bias) | ||||||
|         // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); |         // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); | ||||||
| @ -336,17 +351,17 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| template <typename X, typename Y> | template <typename X, typename Y> | ||||||
| static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // input         [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW) |     // input         [bS, iH, iW, iC]  (NHWC) or [bS, iC, iH, iW]  (NCHW) | ||||||
|     // weightsDepth  [kH, kW, iC, mC]  always |     // weightsDepth  [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] | ||||||
|     // weightsPoint  [1, 1, iC*mC, oC] always |     // weightsPoint  [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] | ||||||
|     // bias          [oC], oC = iC*mC if weightsPoint=nullptr |     // bias          [oC], oC = iC*mC if weightsPoint=nullptr | ||||||
|     // output is     [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW]  (NCHW) |     // output is     [bS, oH, oW, oC]  (NHWC) or [bS, oC, oH, oW]  (NCHW) | ||||||
| 
 | 
 | ||||||
| @ -363,7 +378,7 @@ static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDAr | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier, output channels, output height/width |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier, output channels, output height/width | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weightsDepth->sizeAt(indWmC);                      // channels multiplier |     mC = weightsDepth->sizeAt(indWmC);                      // channels multiplier | ||||||
| 
 | 
 | ||||||
|     NDArray* outputDepth = output; |     NDArray* outputDepth = output; | ||||||
| @ -371,18 +386,18 @@ static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDAr | |||||||
|         outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); |         outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); | ||||||
| 
 | 
 | ||||||
|     // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // |     // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // | ||||||
|     ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW); |     ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // |     // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // | ||||||
|     if (weightsPoint) { |     if (weightsPoint) { | ||||||
|         ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW);             // in this case oH=iH, oW=iW |         ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat);             // in this case oH=iH, oW=iW | ||||||
|         delete outputDepth; |         delete outputDepth; | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias,  NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| @ -1176,15 +1191,15 @@ void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& inp | |||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| template <typename X, typename Y> | template <typename X, typename Y> | ||||||
| static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) |     // input   [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) | ||||||
|     // weights [kH, kW, iC, oC] always |     // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] | ||||||
|     // bias    [oC] |     // bias    [oC] | ||||||
|     // gradO   [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next |     // gradO   [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next | ||||||
| 
 | 
 | ||||||
|     // gradI    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon |     // gradI    [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon | ||||||
|     // gradW    [kH, kW, iC, oC] always |     // gradW    [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] | ||||||
|     // gradB    [oC] |     // gradB    [oC] | ||||||
| 
 | 
 | ||||||
|     // kH         filter(kernel) height |     // kH         filter(kernel) height | ||||||
| @ -1200,7 +1215,7 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
| @ -1214,13 +1229,27 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA | |||||||
|         gradOaxesForDot  = {0, 2, 3};                                           // bS, oH, oW |         gradOaxesForDot  = {0, 2, 3};                                           // bS, oH, oW | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     std::vector<int> wPermut, colPermut; | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         wPermut   = {2, 0, 1, 3}; | ||||||
|  |         colPermut = {2, 3, 1, 0, 4, 5}; | ||||||
|  |     } | ||||||
|  |     else if(1 == wFormat) { | ||||||
|  |         wPermut   = {1, 2, 3, 0}; | ||||||
|  |         colPermut = {1, 2, 3, 0, 4, 5}; | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         wPermut   = {3, 1, 2, 0}; | ||||||
|  |         colPermut = {2, 3, 1, 0, 4, 5}; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); |     NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); | ||||||
| 
 | 
 | ||||||
|     // ----- calculation of gradW ----- // |     // ----- calculation of gradW ----- // | ||||||
|     if(gradW) { |     if(gradW) { | ||||||
|         auto ctx = block.launchContext(); |         auto ctx = block.launchContext(); | ||||||
|         helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));   // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] |         helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));   // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] | ||||||
|         sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3});       // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] |         sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut);       // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // ----- calculation of gradB ----- // |     // ----- calculation of gradB ----- // | ||||||
| @ -1234,7 +1263,10 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     //----- calculation of gradI -----// |     //----- calculation of gradI -----// | ||||||
|     sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5});  // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] |     // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] | ||||||
|  |     // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] | ||||||
|  |     // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] | ||||||
|  |     sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut);  // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] | ||||||
| 
 | 
 | ||||||
|     helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);                          // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] |     helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);                          // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] | ||||||
| 
 | 
 | ||||||
| @ -1245,20 +1277,20 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| template <typename X, typename Y> | template <typename X, typename Y> | ||||||
| static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // input    [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) |     // input    [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) | ||||||
|     // weights  [kH, kW, iC, mC] always |     // weights  [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] | ||||||
|     // bias     [oC] = [iC*mC] |     // bias     [oC] = [iC*mC] | ||||||
|     // gradO    [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next |     // gradO    [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next | ||||||
|     // gradI    [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon |     // gradI    [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon | ||||||
|     // gradW    [kH, kW, iC, mC] always |     // gradW    [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] | ||||||
|     // gradB    [oC] |     // gradB    [oC] | ||||||
| 
 | 
 | ||||||
|     //  kH          filter(kernel) height |     //  kH          filter(kernel) height | ||||||
| @ -1274,11 +1306,11 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC);                           // channels multiplier |     mC = weights->sizeAt(indWmC);                           // channels multiplier | ||||||
| 
 | 
 | ||||||
|     std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}};      // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] |     std::vector<std::vector<Nd4jLong>> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}};      // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] | ||||||
|     std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2; |     std::vector<std::vector<Nd4jLong>> modifGradO1, modifGradO2, modifWeights; | ||||||
|     std::vector<Nd4jLong> gradOreShape; |     std::vector<Nd4jLong> gradOreShape; | ||||||
| 
 | 
 | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
| @ -1294,6 +1326,13 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con | |||||||
|         modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}};                                   // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] |         modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}};                                   // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     if(0 == wFormat) | ||||||
|  |         modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; | ||||||
|  |     else if(1 == wFormat) | ||||||
|  |         modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; | ||||||
|  |     else | ||||||
|  |         modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; | ||||||
|  | 
 | ||||||
|     if(paddingMode == 1)                       // SAME |     if(paddingMode == 1)                       // SAME | ||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); |         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); | ||||||
| 
 | 
 | ||||||
| @ -1303,7 +1342,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con | |||||||
|     // ----- calculation of gradW and gradB ----- // |     // ----- calculation of gradW and gradB ----- // | ||||||
| 
 | 
 | ||||||
|     helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] |     helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] | ||||||
|     sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}});  // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] |     sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights);  // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] | ||||||
| 
 | 
 | ||||||
|     // ----- calculation of gradB ----- // |     // ----- calculation of gradB ----- // | ||||||
|     if(gradB) { |     if(gradB) { | ||||||
| @ -1316,7 +1355,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     //----- calculation of gradI -----// |     //----- calculation of gradI -----// | ||||||
|     sd::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] |     sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] | ||||||
|     helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);                                       // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] |     helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW);                                       // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] | ||||||
| 
 | 
 | ||||||
|     if(!isNCHW) { |     if(!isNCHW) { | ||||||
| @ -1326,8 +1365,8 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { | void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
|     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); |     BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -102,7 +102,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>  expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong>  expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|     std::vector<Nd4jLong>  expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong>  expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|  | |||||||
| @ -54,7 +54,7 @@ PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>  expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong>  expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); |     REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); | ||||||
| @ -108,7 +108,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|  | |||||||
| @ -34,22 +34,25 @@ static void conv2dCUDNN(const LaunchContext* context, | |||||||
|                         const int sH, const int sW, |                         const int sH, const int sW, | ||||||
|                         const int pH, const int pW, |                         const int pH, const int pW, | ||||||
|                         const int dH, const int dW, |                         const int dH, const int dW, | ||||||
|                         const int paddingMode, const bool isNCHW) { |                         const int paddingMode, const bool isNCHW, const int wFormat) { | ||||||
|  | 
 | ||||||
|  |     // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); |     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); | ||||||
|     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); |     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); | ||||||
|     if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); |     if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); | ||||||
| 
 | 
 | ||||||
|     cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; |     cudnnTensorFormat_t format  = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; | ||||||
|  |     cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); | ||||||
| 
 | 
 | ||||||
|     // input descriptor |     // input descriptor | ||||||
|     cudnnTensorDescriptor_t x; |     cudnnTensorDescriptor_t x; | ||||||
|     cudnnCreateTensorDescriptor(&x); |     cudnnCreateTensorDescriptor(&x); | ||||||
|     if(input->ews() == 1) |     if(input->ews() == 1 && input->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); |         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); |         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); | ||||||
| @ -58,13 +61,13 @@ static void conv2dCUDNN(const LaunchContext* context, | |||||||
|     // weights descriptor |     // weights descriptor | ||||||
|     cudnnFilterDescriptor_t w; |     cudnnFilterDescriptor_t w; | ||||||
|     cudnnCreateFilterDescriptor(&w); |     cudnnCreateFilterDescriptor(&w); | ||||||
|     err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW); |     err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), formatW, oC, iC, kH, kW); | ||||||
|     if(err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); |     if(err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); | ||||||
| 
 | 
 | ||||||
|     // output descriptor |     // output descriptor | ||||||
|     cudnnTensorDescriptor_t z; |     cudnnTensorDescriptor_t z; | ||||||
|     cudnnCreateTensorDescriptor(&z); |     cudnnCreateTensorDescriptor(&z); | ||||||
|     if(output->ews() == 1) |     if(output->ews() == 1 && output->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); |         err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); |         err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); | ||||||
| @ -104,10 +107,10 @@ static void conv2dCUDNN(const LaunchContext* context, | |||||||
| 
 | 
 | ||||||
|     // add bias if it is present |     // add bias if it is present | ||||||
|     if (bias != nullptr) { |     if (bias != nullptr) { | ||||||
| 
 |  | ||||||
|         cudnnTensorDescriptor_t b; |         cudnnTensorDescriptor_t b; | ||||||
|         cudnnCreateTensorDescriptor(&b); |         cudnnCreateTensorDescriptor(&b); | ||||||
|         err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); |         // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); | ||||||
|  |         err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); | ||||||
|         if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); |         if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); | ||||||
|         err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); |         err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); | ||||||
|         if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", err); |         if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", err); | ||||||
| @ -131,22 +134,23 @@ static void conv2dBpCUDNN(const LaunchContext* context, | |||||||
|                           const int sH, const int sW, |                           const int sH, const int sW, | ||||||
|                           const int pH, const int pW, |                           const int pH, const int pW, | ||||||
|                           const int dH, const int dW, |                           const int dH, const int dW, | ||||||
|                           const int paddingMode, const bool isNCHW) { |                           const int paddingMode, const bool isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); |     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); | ||||||
|     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); |     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); | ||||||
|     if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", err); |     if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", err); | ||||||
| 
 | 
 | ||||||
|     cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; |     cudnnTensorFormat_t format  = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; | ||||||
|  |     cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); | ||||||
| 
 | 
 | ||||||
|     // input descriptor |     // input descriptor | ||||||
|     cudnnTensorDescriptor_t x; |     cudnnTensorDescriptor_t x; | ||||||
|     cudnnCreateTensorDescriptor(&x); |     cudnnCreateTensorDescriptor(&x); | ||||||
|     if(input->ews() == 1) |     if(input->ews() == 1 && input->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); |         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); |         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); | ||||||
| @ -155,7 +159,7 @@ static void conv2dBpCUDNN(const LaunchContext* context, | |||||||
|     // gradO descriptor |     // gradO descriptor | ||||||
|     cudnnTensorDescriptor_t dz; |     cudnnTensorDescriptor_t dz; | ||||||
|     cudnnCreateTensorDescriptor(&dz); |     cudnnCreateTensorDescriptor(&dz); | ||||||
|     if(gradO->ews() == 1) |     if(gradO->ews() == 1 && gradO->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); |         err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); |         err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); | ||||||
| @ -164,7 +168,7 @@ static void conv2dBpCUDNN(const LaunchContext* context, | |||||||
|     // gradI descriptor |     // gradI descriptor | ||||||
|     cudnnTensorDescriptor_t dx; |     cudnnTensorDescriptor_t dx; | ||||||
|     cudnnCreateTensorDescriptor(&dx); |     cudnnCreateTensorDescriptor(&dx); | ||||||
|     if(gradI->ews() == 1) |     if(gradI->ews() == 1 && gradI->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); |         err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); |         err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); | ||||||
| @ -173,7 +177,7 @@ static void conv2dBpCUDNN(const LaunchContext* context, | |||||||
|     // gradW descriptor |     // gradW descriptor | ||||||
|     cudnnFilterDescriptor_t dw; |     cudnnFilterDescriptor_t dw; | ||||||
|     cudnnCreateFilterDescriptor(&dw); |     cudnnCreateFilterDescriptor(&dw); | ||||||
|     err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW); |     err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, oC, iC, kH, kW); | ||||||
|     if(err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); |     if(err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); | ||||||
| 
 | 
 | ||||||
|     // description of convolution |     // description of convolution | ||||||
| @ -220,7 +224,8 @@ static void conv2dBpCUDNN(const LaunchContext* context, | |||||||
|     if(gradB != nullptr) { |     if(gradB != nullptr) { | ||||||
|         cudnnTensorDescriptor_t db; |         cudnnTensorDescriptor_t db; | ||||||
|         cudnnCreateTensorDescriptor(&db); |         cudnnCreateTensorDescriptor(&db); | ||||||
|         err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); |         // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); | ||||||
|  |         err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); | ||||||
|         if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); |         if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); | ||||||
| 
 | 
 | ||||||
|         err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); |         err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); | ||||||
| @ -251,7 +256,7 @@ static void conv2dBpCUDNN(const LaunchContext* context, | |||||||
| PLATFORM_IMPL(conv2d, ENGINE_CUDA) { | PLATFORM_IMPL(conv2d, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, oC] always |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) | ||||||
| @ -263,7 +268,8 @@ PLATFORM_IMPL(conv2d, ENGINE_CUDA) { | |||||||
|     int dH = INT_ARG(6);                                                        // dilations height |     int dH = INT_ARG(6);                                                        // dilations height | ||||||
|     int dW = INT_ARG(7);                                                        // dilations width |     int dW = INT_ARG(7);                                                        // dilations width | ||||||
|     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME | ||||||
|     bool isNCHW    = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC |     bool isNCHW    = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW, 1-NHWC | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height |     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height | ||||||
|     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width |     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width | ||||||
| @ -273,31 +279,35 @@ PLATFORM_IMPL(conv2d, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>  expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) { |     if (bias) { | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
|         REQUIRE_TRUE((bias->rankOf() == 1 && bias->strideAt(0) == 1) || (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && bias->strideAt(1) == 1) || (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && bias->strideAt(0) == 1), 0, "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !"); |         REQUIRE_TRUE((bias->rankOf() == 1 && bias->strideAt(0) == 1) || (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && bias->strideAt(1) == 1) || (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && bias->strideAt(0) == 1), 0, "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !"); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} |     NDArray* newWeights = weights; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} | ||||||
|     newWeights->assign(weights->permute({3,2,0,1})); // permute weights (kH, kW, iC, oC  --> oC, iC, kH, kW) |     if(0 == wFormat) { | ||||||
|  |         newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector<Nd4jLong>({oC, iC, kH, kW}) : std::vector<Nd4jLong>({oC, kH, kW, iC}), weights->dataType(), weights->getContext()); | ||||||
|  |         newWeights->assign(weights->permute(isNCHW ? std::vector<int>({3,2,0,1}) : std::vector<int>({3,0,1,2}))); // (kH, kW, iC, oC  --> oC, iC, kH, kW) or (kH, kW, iC, oC  --> oC, kH, kW, iC) | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     NDArray* newInput = input; |     NDArray* newInput = input; | ||||||
|     NDArray* newGradI = nullptr; |     NDArray* newGradI = nullptr; | ||||||
|     if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings |     if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings | ||||||
|         checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); |         checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); | ||||||
| 
 | 
 | ||||||
|     conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW); |     conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     if(newInput != input) |     if(newInput != input) | ||||||
|         delete newInput; |         delete newInput; | ||||||
| 
 | 
 | ||||||
|     delete newWeights; |     if(0 == wFormat) | ||||||
|  |         delete newWeights; | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -322,12 +332,12 @@ PLATFORM_CHECK(conv2d, ENGINE_CUDA) { | |||||||
| PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { | PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC] always |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, iC, oC] always |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC] |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC] | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0);                                                        // filter(kernel) height |     int kH = INT_ARG(0);                                                        // filter(kernel) height | ||||||
| @ -340,6 +350,7 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width |     int dW = INT_ARG(7);                                                        // dilations width | ||||||
|     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); | ||||||
|     REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); |     REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); | ||||||
| @ -347,7 +358,7 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // true output height, width |     int trueoH, trueoW;          // true output height, width | ||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); | ||||||
| @ -355,26 +366,30 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { | |||||||
|     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     NDArray* newGradW   = new NDArray(gradW->ordering(),   {oC, iC, kH, kW}, gradW->dataType(),   gradW->getContext()); // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} |     NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} | ||||||
|     NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext()); |     if(0 == wFormat) { | ||||||
| 
 |         newGradW   = new NDArray(gradW->ordering(),   isNCHW ? std::vector<Nd4jLong>({oC, iC, kH, kW}) : std::vector<Nd4jLong>({oC, kH, kW, iC}), gradW->dataType(),   gradW->getContext()); | ||||||
|     newWeights->assign(weights->permute({3,2,0,1})); // permute weights (kH, kW, iC, oC  --> oC, iC, kH, kW) |         newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector<Nd4jLong>({oC, iC, kH, kW}) : std::vector<Nd4jLong>({oC, kH, kW, iC}), weights->dataType(), weights->getContext()); | ||||||
|  |         newWeights->assign(weights->permute(isNCHW ? std::vector<int>({3,2,0,1}) : std::vector<int>({3,0,1,2}))); // (kH, kW, iC, oC  --> oC, iC, kH, kW) or (kH, kW, iC, oC  --> oC, kH, kW, iC) | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     NDArray* newInput = input; |     NDArray* newInput = input; | ||||||
|     NDArray* newGradI = gradI; |     NDArray* newGradI = gradI; | ||||||
|     if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings |     if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings | ||||||
|         checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); |         checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); | ||||||
| 
 | 
 | ||||||
|     conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO,   newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); |     conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO,   newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW,wFormat); | ||||||
| 
 | 
 | ||||||
|     newGradW->permutei({2,3,1,0});  // [oC, iC, kH, kW] -> [kH, kW, iC, oC] |     if(0 == wFormat) { | ||||||
|     gradW->assign(newGradW); |         newGradW->permutei(isNCHW ? std::vector<int>({2,3,1,0}) : std::vector<int>({1,2,3,0})); // (oC, iC, kH, kW --> kH, kW, iC, oC) or (oC, kH, kW, iC --> kH, kW, iC, oC) | ||||||
|  |         gradW->assign(newGradW); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     if(newInput != input) { |     if(newInput != input) { | ||||||
| 
 | 
 | ||||||
| @ -387,8 +402,10 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { | |||||||
|         delete newGradI; |         delete newGradI; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     delete newWeights; |     if(0 == wFormat) { | ||||||
|     delete newGradW; |         delete newWeights; | ||||||
|  |         delete newGradW; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
|  | |||||||
| @ -34,13 +34,15 @@ static void conv3dCUDNN(const LaunchContext* context, | |||||||
|                         const int sD, const int sH, const int sW, |                         const int sD, const int sH, const int sW, | ||||||
|                         const int pD, const int pH, const int pW, |                         const int pD, const int pH, const int pW, | ||||||
|                         const int dD, const int dH, const int dW, |                         const int dD, const int dH, const int dW, | ||||||
|                         const int paddingMode, const bool isNCDHW) { |                         const int paddingMode, const bool isNCDHW, const int wFormat) { | ||||||
|  | 
 | ||||||
|  |     // cudnn support only one format for weights {oC,iC,kD,kH,kW} | ||||||
| 
 | 
 | ||||||
|     const int numDims = 5; |     const int numDims = 5; | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); |     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); | ||||||
|     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); |     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); | ||||||
| @ -53,7 +55,7 @@ static void conv3dCUDNN(const LaunchContext* context, | |||||||
|     const std::vector<int> xShape   = {bS, iC, iD, iH, iW}; |     const std::vector<int> xShape   = {bS, iC, iD, iH, iW}; | ||||||
|     const std::vector<int> zShape   = {bS, oC, oD, oH, oW}; |     const std::vector<int> zShape   = {bS, oC, oD, oH, oW}; | ||||||
|     const std::vector<int> wShape   = {oC, iC, kD, kH, kW}; |     const std::vector<int> wShape   = {oC, iC, kD, kH, kW}; | ||||||
|     const std::vector<int> bShape   = {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; |     const std::vector<int> bShape   = {1, oC, 1, 1, 1};         // {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; | ||||||
| 
 | 
 | ||||||
|     const std::vector<int> xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; |     const std::vector<int> xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; | ||||||
|     const std::vector<int> zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; |     const std::vector<int> zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; | ||||||
| @ -120,7 +122,7 @@ static void conv3dCUDNN(const LaunchContext* context, | |||||||
| 
 | 
 | ||||||
|         cudnnTensorDescriptor_t b; |         cudnnTensorDescriptor_t b; | ||||||
|         cudnnCreateTensorDescriptor(&b); |         cudnnCreateTensorDescriptor(&b); | ||||||
|         err = cudnnSetTensorNdDescriptorEx(b, format, cudnnDataType(bias->dataType()), numDims, bShape.data()); |         err = cudnnSetTensorNdDescriptorEx(b, /*format*/CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), numDims, bShape.data()); | ||||||
|         if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err); |         if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err); | ||||||
|         err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); |         err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); | ||||||
|         if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", err); |         if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", err); | ||||||
| @ -144,13 +146,15 @@ static void conv3dBpCUDNN(const LaunchContext* context, | |||||||
|                           const int sD, const int sH, const int sW, |                           const int sD, const int sH, const int sW, | ||||||
|                           const int pD, const int pH, const int pW, |                           const int pD, const int pH, const int pW, | ||||||
|                           const int dD, const int dH, const int dW, |                           const int dD, const int dH, const int dW, | ||||||
|                           const int paddingMode, const bool isNCDHW) { |                           const int paddingMode, const bool isNCDHW, const int wFormat) { | ||||||
|  | 
 | ||||||
|  |     // cudnn supports only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} for weights/gradW | ||||||
| 
 | 
 | ||||||
|     const int numDims = 5; |     const int numDims = 5; | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); |     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); | ||||||
|     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); |     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); | ||||||
| @ -170,6 +174,7 @@ static void conv3dBpCUDNN(const LaunchContext* context, | |||||||
|     const std::vector<int> dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; |     const std::vector<int> dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; | ||||||
| 
 | 
 | ||||||
|     cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; |     cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; | ||||||
|  |     cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); | ||||||
| 
 | 
 | ||||||
|     // input descriptor |     // input descriptor | ||||||
|     cudnnTensorDescriptor_t x; |     cudnnTensorDescriptor_t x; | ||||||
| @ -201,7 +206,7 @@ static void conv3dBpCUDNN(const LaunchContext* context, | |||||||
|     // gradW descriptor |     // gradW descriptor | ||||||
|     cudnnFilterDescriptor_t dw; |     cudnnFilterDescriptor_t dw; | ||||||
|     cudnnCreateFilterDescriptor(&dw); |     cudnnCreateFilterDescriptor(&dw); | ||||||
|     err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, numDims, wShape.data()); |     err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, numDims, wShape.data()); | ||||||
|     if(err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err); |     if(err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err); | ||||||
| 
 | 
 | ||||||
|     // description of convolution |     // description of convolution | ||||||
| @ -280,7 +285,7 @@ static void conv3dBpCUDNN(const LaunchContext* context, | |||||||
| PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { | PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, iC, oC] always |     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) | ||||||
| 
 | 
 | ||||||
| @ -301,34 +306,39 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width |     int dW = INT_ARG(11);                                                       // dilations width | ||||||
|     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID |     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW |     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); |     REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} |     NDArray* newWeights = weights; // cudnn support only one format {oC,iC,kD,kH,kW} | ||||||
|     newWeights->assign(weights->permute({4,3,0,1,2})); // permute weights (kD, kH, kW, iC, oC  --> oC, iC, kD, kH, kW) |     if(1 != wFormat) { | ||||||
|  |         newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); | ||||||
|  |         newWeights->assign(weights->permute(0 == wFormat ? std::vector<int>({4,3,0,1,2}) : std::vector<int>({0,4,1,2,3})));  // kD, kH, kW, iC, oC  --> oC, iC, kD, kH, kW   or oC, kD, kH, kW, iC  --> oC, iC, kD, kH, kW | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     NDArray* newInput = input; |     NDArray* newInput = input; | ||||||
|     NDArray* newGradI = nullptr; |     NDArray* newGradI = nullptr; | ||||||
|     if(paddingMode == 1) // in SAME padding mode cudnn doesn't support asymmetric left/right, top/bottom paddings |     if(paddingMode == 1) // in SAME padding mode cudnn doesn't support asymmetric left/right, top/bottom paddings | ||||||
|         checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); |         checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); | ||||||
| 
 | 
 | ||||||
|     conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW); |     conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     if(newInput != input) |     if(newInput != input) | ||||||
|         delete newInput; |         delete newInput; | ||||||
| 
 | 
 | ||||||
|     delete newWeights; |     if(1 != wFormat) | ||||||
|  |         delete newWeights; | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -337,7 +347,7 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { | |||||||
| PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) { | PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, iC, oC] always |     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] | ||||||
| 
 | 
 | ||||||
|     int paddingMode = INT_ARG(12);                                       // 1-SAME,  0-VALID |     int paddingMode = INT_ARG(12);                                       // 1-SAME,  0-VALID | ||||||
| @ -353,12 +363,12 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) { | |||||||
| PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { | PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, iC, oC] always |     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, iC, oC] always |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC] |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC] | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 5, 0, "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 5, 0, "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); | ||||||
| @ -379,10 +389,11 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width |     int dW = INT_ARG(11);                                                       // dilations width | ||||||
|     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID |     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW |     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     int trueoD, trueoH, trueoW;          // true output depth/height/width |     int trueoD, trueoH, trueoW;          // true output depth/height/width | ||||||
|     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); |     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); | ||||||
| @ -390,7 +401,7 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { | |||||||
|     REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); |     REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CONV3D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CONV3D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
| @ -398,20 +409,25 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     NDArray* newGradW   = new NDArray(gradW->ordering(),   {oC, iC, kD, kH, kW}, gradW->dataType(),   gradW->getContext()); // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} |     NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} | ||||||
|     NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); |     if(0 == wFormat) { | ||||||
| 
 |         newGradW   = new NDArray(gradW->ordering(),   isNCDHW ? std::vector<Nd4jLong>({oC, iC, kD, kH, kW}) : std::vector<Nd4jLong>({oC, kD, kH, kW, iC}), gradW->dataType(),   gradW->getContext()); | ||||||
|     newWeights->assign(weights->permute({4,3,0,1,2})); // permute weights (kD, kH, kW, iC, oC  --> oC, iC, kD, kH, kW) |         newWeights = new NDArray(weights->ordering(), isNCDHW ? std::vector<Nd4jLong>({oC, iC, kD, kH, kW}) : std::vector<Nd4jLong>({oC, kD, kH, kW, iC}), weights->dataType(), weights->getContext()); | ||||||
|  |         newWeights->assign(weights->permute(isNCDHW ? std::vector<int>({4,3,0,1,2}) : std::vector<int>({4,0,1,2,3}))); // (kD, kH, kW, iC, oC  --> oC, iC, kD, kH, kW) or (kD, kH, kW, iC, oC  --> oC, kD, kH, kW, iC) | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     NDArray* newInput = input; |     NDArray* newInput = input; | ||||||
|     NDArray* newGradI = gradI; |     NDArray* newGradI = gradI; | ||||||
|     if(paddingMode == 1) // in SAME padding mode cudnn doesn't support asymmetric left/right, top/bottom paddings |     if(paddingMode == 1) // in SAME padding mode cudnn doesn't support asymmetric left/right, top/bottom paddings | ||||||
|         checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); |         checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); | ||||||
| 
 | 
 | ||||||
|     conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO,   newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW); |     conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO,   newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW,wFormat); | ||||||
|  | 
 | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         newGradW->permutei(isNCDHW ? std::vector<int>({2,3,4,1,0}) : std::vector<int>({1,2,3,4,0})); // (oC, iC, kD, kH, kW --> kD, kH, kW, iC, oC) or (oC, kD, kH, kW, iC --> kD, kH, kW, iC, oC) | ||||||
|  |         gradW->assign(newGradW); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     newGradW->permutei({2,3,4,1,0});    // [oC, iC, kD, kH, kW] -> [kD, kH, kW, iC, oC] |  | ||||||
|     gradW->assign(newGradW); |  | ||||||
| 
 | 
 | ||||||
|     if(newInput != input) { |     if(newInput != input) { | ||||||
| 
 | 
 | ||||||
| @ -424,8 +440,10 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { | |||||||
|         delete newGradI; |         delete newGradI; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     delete newWeights; |     if(0 == wFormat) { | ||||||
|     delete newGradW; |         delete newWeights; | ||||||
|  |         delete newGradW; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -433,7 +451,7 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { | |||||||
| PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { | PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, iC, oC] always |     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -124,7 +124,7 @@ void pooling2dCUDNN(const LaunchContext* context, | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); |     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); | ||||||
|     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); |     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); | ||||||
| @ -135,7 +135,7 @@ void pooling2dCUDNN(const LaunchContext* context, | |||||||
|     // input descriptor |     // input descriptor | ||||||
|     cudnnTensorDescriptor_t x; |     cudnnTensorDescriptor_t x; | ||||||
|     cudnnCreateTensorDescriptor(&x); |     cudnnCreateTensorDescriptor(&x); | ||||||
|     if(input->ews() == 1) |     if(input->ews() == 1 && input->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); |         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); |         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); | ||||||
| @ -144,7 +144,7 @@ void pooling2dCUDNN(const LaunchContext* context, | |||||||
|     // output descriptor |     // output descriptor | ||||||
|     cudnnTensorDescriptor_t z; |     cudnnTensorDescriptor_t z; | ||||||
|     cudnnCreateTensorDescriptor(&z); |     cudnnCreateTensorDescriptor(&z); | ||||||
|     if(output->ews() == 1) |     if(output->ews() == 1 && output->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); |         err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); |         err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); | ||||||
| @ -187,7 +187,7 @@ void pooling2dBpCUDNN(const LaunchContext* context, | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); |     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); | ||||||
|     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); |     cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); | ||||||
| @ -198,7 +198,7 @@ void pooling2dBpCUDNN(const LaunchContext* context, | |||||||
|     // input and gradI descriptor |     // input and gradI descriptor | ||||||
|     cudnnTensorDescriptor_t x; |     cudnnTensorDescriptor_t x; | ||||||
|     cudnnCreateTensorDescriptor(&x); |     cudnnCreateTensorDescriptor(&x); | ||||||
|     if(input->ews() == 1) |     if(input->ews() == 1 && input->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); |         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); |         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); | ||||||
| @ -207,7 +207,7 @@ void pooling2dBpCUDNN(const LaunchContext* context, | |||||||
|     // gradO descriptor |     // gradO descriptor | ||||||
|     cudnnTensorDescriptor_t dz; |     cudnnTensorDescriptor_t dz; | ||||||
|     cudnnCreateTensorDescriptor(&dz); |     cudnnCreateTensorDescriptor(&dz); | ||||||
|     if(gradO->ews() == 1) |     if(gradO->ews() == 1 && gradO->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); |         err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); |         err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); | ||||||
| @ -255,7 +255,7 @@ void pooling3dCUDNN(const LaunchContext* context, | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     const int pSizes[] = {pD, pH, pW}; |     const int pSizes[] = {pD, pH, pW}; | ||||||
|     const int sSizes[] = {sD, sH, sW}; |     const int sSizes[] = {sD, sH, sW}; | ||||||
| @ -272,7 +272,7 @@ void pooling3dCUDNN(const LaunchContext* context, | |||||||
|     // input descriptor |     // input descriptor | ||||||
|     cudnnTensorDescriptor_t x; |     cudnnTensorDescriptor_t x; | ||||||
|     cudnnCreateTensorDescriptor(&x); |     cudnnCreateTensorDescriptor(&x); | ||||||
|     if(input->ews() == 1) |     if(input->ews() == 1 && input->ordering() == 'c') | ||||||
|         err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); |         err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); |         err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); | ||||||
| @ -281,7 +281,7 @@ void pooling3dCUDNN(const LaunchContext* context, | |||||||
|     // output descriptor |     // output descriptor | ||||||
|     cudnnTensorDescriptor_t z; |     cudnnTensorDescriptor_t z; | ||||||
|     cudnnCreateTensorDescriptor(&z); |     cudnnCreateTensorDescriptor(&z); | ||||||
|     if(output->ews() == 1) |     if(output->ews() == 1 && output->ordering() == 'c') | ||||||
|         err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape); |         err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides); |         err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides); | ||||||
| @ -330,7 +330,7 @@ void pooling3dBpCUDNN(const LaunchContext* context, | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     const int pSizes[] = {pD, pH, pW}; |     const int pSizes[] = {pD, pH, pW}; | ||||||
|     const int sSizes[] = {sD, sH, sW}; |     const int sSizes[] = {sD, sH, sW}; | ||||||
| @ -347,7 +347,7 @@ void pooling3dBpCUDNN(const LaunchContext* context, | |||||||
|     // input and gradI descriptor |     // input and gradI descriptor | ||||||
|     cudnnTensorDescriptor_t x; |     cudnnTensorDescriptor_t x; | ||||||
|     cudnnCreateTensorDescriptor(&x); |     cudnnCreateTensorDescriptor(&x); | ||||||
|     if(input->ews() == 1) |     if(input->ews() == 1 && input->ordering() == 'c') | ||||||
|         err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); |         err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); |         err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); | ||||||
| @ -356,7 +356,7 @@ void pooling3dBpCUDNN(const LaunchContext* context, | |||||||
|     // gradO descriptor |     // gradO descriptor | ||||||
|     cudnnTensorDescriptor_t dz; |     cudnnTensorDescriptor_t dz; | ||||||
|     cudnnCreateTensorDescriptor(&dz); |     cudnnCreateTensorDescriptor(&dz); | ||||||
|     if(gradO->ews() == 1) |     if(gradO->ews() == 1 && gradO->ordering() == 'c') | ||||||
|         err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape); |         err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides); |         err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides); | ||||||
|  | |||||||
| @ -39,14 +39,14 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context, | |||||||
|     // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) |     // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) | ||||||
| 
 | 
 | ||||||
|     // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc |     // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc | ||||||
|     // weights [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute |     // weights [iC, mC, kH, kW] | ||||||
|     // bias [oC], may be nullptr |     // bias [oC], may be nullptr | ||||||
|     // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc |     // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc | ||||||
|     // oC = iC*mC |     // oC = iC*mC | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, mC, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;           // corresponding indexes |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;           // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(1); |     mC = weights->sizeAt(1); | ||||||
| 
 | 
 | ||||||
|     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); |     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); | ||||||
| @ -58,7 +58,7 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context, | |||||||
|     // input descriptor |     // input descriptor | ||||||
|     cudnnTensorDescriptor_t x; |     cudnnTensorDescriptor_t x; | ||||||
|     cudnnCreateTensorDescriptor(&x); |     cudnnCreateTensorDescriptor(&x); | ||||||
|     if(input->ews() == 1) |     if(input->ews() == 1 && input->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); |         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); |         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); | ||||||
| @ -73,7 +73,7 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context, | |||||||
|     // output descriptor |     // output descriptor | ||||||
|     cudnnTensorDescriptor_t z; |     cudnnTensorDescriptor_t z; | ||||||
|     cudnnCreateTensorDescriptor(&z); |     cudnnCreateTensorDescriptor(&z); | ||||||
|     if(output->ews() == 1) |     if(output->ews() == 1 && output->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); |         err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); |         err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); | ||||||
| @ -117,7 +117,8 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context, | |||||||
| 
 | 
 | ||||||
|         cudnnTensorDescriptor_t b; |         cudnnTensorDescriptor_t b; | ||||||
|         cudnnCreateTensorDescriptor(&b); |         cudnnCreateTensorDescriptor(&b); | ||||||
|         err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); |         // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); | ||||||
|  |         err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); | ||||||
|         if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); |         if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); | ||||||
|         err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); |         err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); | ||||||
|         if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err); |         if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err); | ||||||
| @ -146,14 +147,14 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, | |||||||
|     // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) |     // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) | ||||||
| 
 | 
 | ||||||
|     // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc |     // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc | ||||||
|     // weights, gradW [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute |     // weights, gradW [iC, mC, kH, kW] | ||||||
|     // gradB [oC], may be nullptr |     // gradB [oC], may be nullptr | ||||||
|     // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc |     // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc | ||||||
|     // oC = iC*mC |     // oC = iC*mC | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, mC, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;           // corresponding indexes |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;           // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(1); |     mC = weights->sizeAt(1); | ||||||
| 
 | 
 | ||||||
|     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); |     auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle()); | ||||||
| @ -165,7 +166,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, | |||||||
|     // input descriptor |     // input descriptor | ||||||
|     cudnnTensorDescriptor_t x; |     cudnnTensorDescriptor_t x; | ||||||
|     cudnnCreateTensorDescriptor(&x); |     cudnnCreateTensorDescriptor(&x); | ||||||
|     if(input->ews() == 1) |     if(input->ews() == 1 && input->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); |         err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); |         err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); | ||||||
| @ -174,7 +175,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, | |||||||
|     // gradO descriptor |     // gradO descriptor | ||||||
|     cudnnTensorDescriptor_t dz; |     cudnnTensorDescriptor_t dz; | ||||||
|     cudnnCreateTensorDescriptor(&dz); |     cudnnCreateTensorDescriptor(&dz); | ||||||
|     if(gradO->ews() == 1) |     if(gradO->ews() == 1 && gradO->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); |         err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); |         err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); | ||||||
| @ -183,7 +184,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, | |||||||
|     // gradI descriptor |     // gradI descriptor | ||||||
|     cudnnTensorDescriptor_t dx; |     cudnnTensorDescriptor_t dx; | ||||||
|     cudnnCreateTensorDescriptor(&dx); |     cudnnCreateTensorDescriptor(&dx); | ||||||
|     if(gradI->ews() == 1) |     if(gradI->ews() == 1 && gradI->ordering() == 'c') | ||||||
|         err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); |         err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); | ||||||
|     else |     else | ||||||
|         err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); |         err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); | ||||||
| @ -241,7 +242,8 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, | |||||||
|     if(gradB != nullptr) { |     if(gradB != nullptr) { | ||||||
|         cudnnTensorDescriptor_t db; |         cudnnTensorDescriptor_t db; | ||||||
|         cudnnCreateTensorDescriptor(&db); |         cudnnCreateTensorDescriptor(&db); | ||||||
|         err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); |         // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); | ||||||
|  |         err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); | ||||||
|         if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); |         if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); | ||||||
| 
 | 
 | ||||||
|         err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); |         err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); | ||||||
| @ -272,7 +274,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, | |||||||
| PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { | PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC] always |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] = iC*mC |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] = iC*mC | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) | ||||||
| @ -290,22 +292,31 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width |     int dW = INT_ARG(7);                                                        // dilations width | ||||||
|     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME | ||||||
|     int isNCHW      = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;      // INT_ARG(9): 0-NCHW,  1-NHWC |     int isNCHW      = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;      // INT_ARG(9): 0-NCHW,  1-NHWC | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC);                           // channels multiplier |     mC = weights->sizeAt(indWmC);                           // channels multiplier | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); |     REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support format {oC, iC/groupCount, kH, kW} |     std::vector<int> wPermut;     // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} in our case | ||||||
|     newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC  --> iC, mC, kH, kW) |     if(0 == wFormat) | ||||||
|  |         wPermut = {2,3,0,1};         // kH, kW, iC, mC -> iC, mC, kH, kW | ||||||
|  |     else if(1 == wFormat) | ||||||
|  |         wPermut = {1,0,2,3};         // mC, iC, kH, kW -> iC, mC, kH, kW | ||||||
|  |     else | ||||||
|  |         wPermut = {3,0,1,2};         // mC, kH, kW, iC -> iC, mC, kH, kW | ||||||
|  | 
 | ||||||
|  |     NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); | ||||||
|  |     newWeights->assign(weights->permute(wPermut)); | ||||||
| 
 | 
 | ||||||
|     NDArray* newInput = input; |     NDArray* newInput = input; | ||||||
|     NDArray* newGradI = nullptr; |     NDArray* newGradI = nullptr; | ||||||
| @ -326,12 +337,13 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { | |||||||
| PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { | PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC] always |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] = iC*mC |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] = iC*mC | ||||||
| 
 | 
 | ||||||
|     const int paddingMode = INT_ARG(8);                                  // 0-VALID, 1-SAME, 2-CAUSAL |     const int paddingMode = INT_ARG(8);                                  // 0-VALID, 1-SAME, 2-CAUSAL | ||||||
|  |     const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;       // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] | ||||||
| 
 | 
 | ||||||
|     const int mC = weights->sizeAt(3); |     const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); | ||||||
| 
 | 
 | ||||||
|     const bool badInputType   = input->dataType()   != DataType::DOUBLE && input->dataType()   != DataType::FLOAT32 && input->dataType()   != DataType::HALF; |     const bool badInputType   = input->dataType()   != DataType::DOUBLE && input->dataType()   != DataType::FLOAT32 && input->dataType()   != DataType::HALF; | ||||||
|     const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; |     const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; | ||||||
| @ -344,12 +356,12 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { | |||||||
| PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { | PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, mC] always |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC] |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC] | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, iC, mC] always |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC] |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC] | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); | ||||||
| @ -366,10 +378,11 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width |     int dW = INT_ARG(7);                                                        // dilations width | ||||||
|     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC);                           // channels multiplier |     mC = weights->sizeAt(indWmC);                           // channels multiplier | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // correct output height, width |     int trueoH, trueoW;          // correct output height, width | ||||||
| @ -378,17 +391,30 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { | |||||||
|     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|  |     std::vector<int> wPermut, gradWPermut;     // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         wPermut = {2,3,0,1};         // kH, kW, iC, mC -> iC, mC, kH, kW | ||||||
|  |         gradWPermut = {2,3,0,1};     // iC, mC, kH, kW -> kH, kW, iC, mC | ||||||
|  |     } | ||||||
|  |     else if(1 == wFormat) { | ||||||
|  |         wPermut = {1,0,2,3};         // mC, iC, kH, kW -> iC, mC, kH, kW | ||||||
|  |         gradWPermut = {1,0,2,3};     // iC, mC, kH, kW -> mC, iC, kH, kW | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         wPermut = {3,0,1,2};         // mC, kH, kW, iC -> iC, mC, kH, kW | ||||||
|  |         gradWPermut = {1,2,3,0};     // iC, mC, kH, kW -> mC, kH, kW, iC | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     NDArray* newGradW   = new NDArray(gradW->ordering(),   {iC, mC, kH, kW}, gradW->dataType(),   gradW->getContext());     // cudnn support format {oC, iC/groupCount, kH, kW} |     NDArray* newGradW   = new NDArray(gradW->ordering(),   {iC, mC, kH, kW}, gradW->dataType(),   gradW->getContext()); | ||||||
|     NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); |     NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); | ||||||
| 
 | 
 | ||||||
|     newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC  --> iC, mC, kH, kW) |     newWeights->assign(weights->permute(wPermut)); | ||||||
| 
 | 
 | ||||||
|     NDArray* newInput = input; |     NDArray* newInput = input; | ||||||
|     NDArray* newGradI = gradI; |     NDArray* newGradI = gradI; | ||||||
| @ -397,7 +423,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO,   newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); |     depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO,   newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); | ||||||
| 
 | 
 | ||||||
|     newGradW->permutei({2,3,0,1});  // [iC, mC, kH, kW] -> [kH, kW, iC, mC] |     newGradW->permutei(gradWPermut); | ||||||
|     gradW->assign(newGradW); |     gradW->assign(newGradW); | ||||||
| 
 | 
 | ||||||
|     if(newInput != input) { |     if(newInput != input) { | ||||||
| @ -420,14 +446,15 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { | |||||||
| PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { | PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, mC] always |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC] |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC] | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next | ||||||
| 
 | 
 | ||||||
|     const int paddingMode = INT_ARG(8);                                             // 0-VALID, 1-SAME, 2-CAUSAL |     const int paddingMode = INT_ARG(8);                                             // 0-VALID, 1-SAME, 2-CAUSAL | ||||||
|     const int isNCHW      = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;    // INT_ARG(9): 0-NCHW, 1-NHWC |     const int isNCHW      = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;    // INT_ARG(9): 0-NCHW, 1-NHWC | ||||||
|  |     const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;       // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] | ||||||
| 
 | 
 | ||||||
|     const int mC = weights->sizeAt(3); |     const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); | ||||||
| 
 | 
 | ||||||
|     const bool badInputType   = input->dataType()   != DataType::DOUBLE && input->dataType()   != DataType::FLOAT32 && input->dataType()   != DataType::HALF; |     const bool badInputType   = input->dataType()   != DataType::DOUBLE && input->dataType()   != DataType::FLOAT32 && input->dataType()   != DataType::HALF; | ||||||
|     const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; |     const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; | ||||||
|  | |||||||
| @ -98,7 +98,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width; | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>  expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong>  expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|     std::vector<Nd4jLong>  expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong>  expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|  | |||||||
| @ -54,7 +54,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); |     REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); | ||||||
| @ -106,7 +106,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;               // batch size, input channels, input depth/height/width, output channels, output depth/height/width; |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;               // batch size, input channels, input depth/height/width, output channels, output depth/height/width; | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|  | |||||||
| @ -60,7 +60,7 @@ PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     if (paddingMode) |     if (paddingMode) | ||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); |         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); | ||||||
| @ -105,7 +105,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1}); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|  | |||||||
| @ -61,7 +61,7 @@ PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     if(paddingMode)                       // SAME
 |     if(paddingMode)                       // SAME
 | ||||||
|         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); |         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); | ||||||
| @ -109,7 +109,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|  | |||||||
| @ -91,12 +91,12 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray | |||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); | ||||||
| 
 | 
 | ||||||
|     mkldnnUtils::setBlockStrides(x, xRank, x_user_md); |     mkldnnUtils::setBlockStrides(x, x_user_md); | ||||||
|     // z, output
 |     // z, output
 | ||||||
|     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); |     dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); | ||||||
| 
 | 
 | ||||||
|     mkldnnUtils::setBlockStrides(z, xRank, z_user_md); |     mkldnnUtils::setBlockStrides(z, z_user_md); | ||||||
| 
 | 
 | ||||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -112,9 +112,9 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray | |||||||
|     // provide memory and check whether reorder is required
 |     // provide memory and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // x
 |     // x
 | ||||||
|     mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
|      | 
 | ||||||
|     // z 
 |     // z
 | ||||||
|     auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); |     auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); | ||||||
|     const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); |     const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); | ||||||
|     auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; |     auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; | ||||||
| @ -207,19 +207,19 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const | |||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); | ||||||
| 
 | 
 | ||||||
|     mkldnnUtils::setBlockStrides(x, xRank, x_user_md); |     mkldnnUtils::setBlockStrides(x, x_user_md); | ||||||
|      | 
 | ||||||
|     // dLdO
 |     // dLdO
 | ||||||
|     dnnl::memory::desc dLdO_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc dLdO_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); |     dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); | ||||||
| 
 | 
 | ||||||
|     mkldnnUtils::setBlockStrides(dLdO, xRank, dLdO_user_md); |     mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md); | ||||||
| 
 | 
 | ||||||
|     // dLdI
 |     // dLdI
 | ||||||
|     dnnl::memory::desc dLdI_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc dLdI_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); |     dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); | ||||||
| 
 | 
 | ||||||
|     mkldnnUtils::setBlockStrides(dLdI, xRank, dLdI_user_md); |     mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md); | ||||||
| 
 | 
 | ||||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -239,10 +239,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const | |||||||
|     // provide memory and check whether reorder is required
 |     // provide memory and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // x
 |     // x
 | ||||||
|     mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_bp_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // dLdO
 |     // dLdO
 | ||||||
|     mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, args, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); |     mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); | ||||||
| 
 | 
 | ||||||
|     // mean
 |     // mean
 | ||||||
|     auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); |     auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); | ||||||
|  | |||||||
| @ -38,13 +38,13 @@ namespace platforms { | |||||||
| static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, | static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, | ||||||
|                           const NDArray *bias, NDArray *output, |                           const NDArray *bias, NDArray *output, | ||||||
|                           const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, |                           const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, | ||||||
|                           const int paddingMode, const int isNCHW) { |                           const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // weights [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW]
 |     // mkl support weights in [oC, iC, kH, kW] format only
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 |     const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 | ||||||
| 
 | 
 | ||||||
| @ -53,8 +53,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
|     dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; |     dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; | ||||||
|     dnnl::memory::dims dilation  = { dH-1, dW-1}; |     dnnl::memory::dims dilation  = { dH-1, dW-1}; | ||||||
| 
 | 
 | ||||||
|     auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; |     auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {oC, iC, kH, kW}; |     dnnl::memory::dims wDims = {oC, iC, kH, kW}; | ||||||
| @ -66,17 +66,29 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 4, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);   // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
 |         w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); |         uint i0, i1, i2, i3; | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); |         if(0 == wFormat) { | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); |             i0 = 3; i1 = 2; i2 = 0; i3 = 1;     // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
 | ||||||
|  |         } | ||||||
|  |         else if(1 == wFormat) { | ||||||
|  |             i0 = 0; i1 = 1; i2 = 2; i3 = 3; | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             i0 = 0; i1 = 3; i2 = 1; i3 = 2;     // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
 | ||||||
|  |         } | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     dnnl::memory::desc b_mkl_md; |     dnnl::memory::desc b_mkl_md; | ||||||
| @ -85,9 +97,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
| 
 | 
 | ||||||
|     // output
 |     // output
 | ||||||
|     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); |     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); | ||||||
| 
 |     mkldnnUtils::setBlockStrides(output, z_user_md); | ||||||
|     mkldnnUtils::setBlockStrides(output, 4, z_user_md); |  | ||||||
| 
 | 
 | ||||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -103,10 +114,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     if(bias != nullptr) { |     if(bias != nullptr) { | ||||||
| @ -135,13 +146,13 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
| static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, | static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, | ||||||
|                             NDArray *gradI, NDArray *gradW, NDArray *gradB, |                             NDArray *gradI, NDArray *gradW, NDArray *gradB, | ||||||
|                             const int kH, const int kW, const int sH, const int sW, const int pH, const  int pW, const int dH, const int dW, |                             const int kH, const int kW, const int sH, const int sW, const int pH, const  int pW, const int dH, const int dW, | ||||||
|                             const int paddingMode, const int isNCHW) { |                             const int paddingMode, const int isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // weights/gradW [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW]
 |     // mkl support weights/gradW in [oC, iC, kH, kW] format only
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 |     const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 | ||||||
| 
 | 
 | ||||||
| @ -150,8 +161,8 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N | |||||||
|     dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; |     dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; | ||||||
|     dnnl::memory::dims dilation  = { dH-1, dW-1}; |     dnnl::memory::dims dilation  = { dH-1, dW-1}; | ||||||
| 
 | 
 | ||||||
|     auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; |     auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {oC, iC, kH, kW}; |     dnnl::memory::dims wDims = {oC, iC, kH, kW}; | ||||||
| @ -163,36 +174,60 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 4, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);   // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
 |         w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); |         uint i0, i1, i2, i3; | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); |         if(0 == wFormat) { | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); |             i0 = 3; i1 = 2; i2 = 0; i3 = 1;     // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
 | ||||||
|  |         } | ||||||
|  |         else if(1 == wFormat) { | ||||||
|  |             i0 = 0; i1 = 1; i2 = 2; i3 = 3; | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             i0 = 0; i1 = 3; i2 = 1; i3 = 2;     // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
 | ||||||
|  |         } | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); |     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); |     mkldnnUtils::setBlockStrides(gradO, gradO_user_md); | ||||||
|      | 
 | ||||||
|     // gradI
 |     // gradI
 | ||||||
|     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); |     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); |     mkldnnUtils::setBlockStrides(gradI, gradI_user_md); | ||||||
|      | 
 | ||||||
|     // gradW
 |     // gradW
 | ||||||
|     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); |     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); | ||||||
|     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3);   // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
 |         gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(2); |         uint i0, i1, i2, i3; | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); |         if(0 == wFormat) { | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); |             i0 = 3; i1 = 2; i2 = 0; i3 = 1;     // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
 | ||||||
|  |         } | ||||||
|  |         else if(1 == wFormat) { | ||||||
|  |             i0 = 0; i1 = 1; i2 = 2; i3 = 3; | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             i0 = 0; i1 = 3; i2 = 1; i3 = 2;     // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
 | ||||||
|  |         } | ||||||
|  |         gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); | ||||||
|  |         gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); | ||||||
|  |         gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); | ||||||
|  |         gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     // gradB
 |     // gradB
 | ||||||
|     dnnl::memory::desc gradB_mkl_md; |     dnnl::memory::desc gradB_mkl_md; | ||||||
| @ -221,10 +256,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md,  op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md,  op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|      mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md,  op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |      mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); |     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); | ||||||
| @ -489,7 +524,7 @@ static void conv2dBpMKLDNN(sd::graph::Context &block, | |||||||
| PLATFORM_IMPL(conv2d, ENGINE_CPU) { | PLATFORM_IMPL(conv2d, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 | ||||||
| @ -500,24 +535,25 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) { | |||||||
|     int pW = INT_ARG(5);                                                        // paddings width
 |     int pW = INT_ARG(5);                                                        // paddings width
 | ||||||
|     int dH = INT_ARG(6);                                                        // dilations height
 |     int dH = INT_ARG(6);                                                        // dilations height
 | ||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME
 | ||||||
|     bool isNCHW    = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     bool isNCHW    = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
 |     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
 | ||||||
|     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
 |     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>  expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); |     conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -536,12 +572,12 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) { | |||||||
| PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { | PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
|     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kH, kW, iC, oC] always
 |     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0);                                                        // filter(kernel) height
 |     int kH = INT_ARG(0);                                                        // filter(kernel) height
 | ||||||
| @ -554,10 +590,11 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME
 |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // true output height, width
 |     int trueoH, trueoW;          // true output height, width
 | ||||||
|     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); |     ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); | ||||||
| @ -566,13 +603,13 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { | |||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); |     conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
|  | |||||||
| @ -40,13 +40,13 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
|                         const int sD, const int sH, const int sW, |                         const int sD, const int sH, const int sW, | ||||||
|                         const int pD, const int pH, const int pW, |                         const int pD, const int pH, const int pW, | ||||||
|                         const int dD, const int dH, const int dW, |                         const int dD, const int dH, const int dW, | ||||||
|                         const int paddingMode, const int isNCDHW) { |                         const int paddingMode, const int isNCDHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // weights [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW]
 |     // mkl support weights  in [oC, iC, kD, kH, kW] format only
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 |     // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 | ||||||
| 
 | 
 | ||||||
| @ -56,8 +56,8 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
|     dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; |     dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; | ||||||
|     dnnl::memory::dims dilation  = {dD-1, dH-1, dW-1}; |     dnnl::memory::dims dilation  = {dD-1, dH-1, dW-1}; | ||||||
| 
 | 
 | ||||||
|     auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; |     auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; |     dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; | ||||||
| @ -69,18 +69,30 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 5, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4);   // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
 |         w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); |         uint i0, i1, i2, i3, i4; | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); |         if(0 == wFormat) { | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); |             i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2;     // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); |         } | ||||||
|  |         else if(1 == wFormat) { | ||||||
|  |             i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3;     // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
 | ||||||
|  |         } | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     dnnl::memory::desc b_mkl_md; |     dnnl::memory::desc b_mkl_md; | ||||||
| @ -89,8 +101,8 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
| 
 | 
 | ||||||
|     // output
 |     // output
 | ||||||
|     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); |     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(output, 5, z_user_md); |     mkldnnUtils::setBlockStrides(output, z_user_md); | ||||||
| 
 | 
 | ||||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -106,11 +118,11 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md,  op_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md,  op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md,  op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md,  op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
|   | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     if(bias != nullptr) { |     if(bias != nullptr) { | ||||||
|         auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); |         auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); | ||||||
| @ -140,13 +152,13 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N | |||||||
|                             const int sD, const int sH, const int sW, |                             const int sD, const int sH, const int sW, | ||||||
|                             const int pD, const int pH, const int pW, |                             const int pD, const int pH, const int pW, | ||||||
|                             const int dD, const int dH, const int dW, |                             const int dD, const int dH, const int dW, | ||||||
|                             const int paddingMode, const int isNCDHW) { |                             const int paddingMode, const int isNCDHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // weights/gradW [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW]
 |     // mkl support weights/gradW in [oC, iC, kD, kH, kW] format only
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 |     // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 | ||||||
| 
 | 
 | ||||||
| @ -156,8 +168,8 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N | |||||||
|     dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; |     dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; | ||||||
|     dnnl::memory::dims dilation  = {dD-1, dH-1, dW-1}; |     dnnl::memory::dims dilation  = {dD-1, dH-1, dW-1}; | ||||||
| 
 | 
 | ||||||
|     auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; |     auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; |     dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; | ||||||
| @ -169,40 +181,64 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 5, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4);   // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
 |         w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); |         uint i0, i1, i2, i3, i4; | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); |         if(0 == wFormat) { | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); |             i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2;     // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); |         } | ||||||
|  |         else if(1 == wFormat) { | ||||||
|  |             i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3;     // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
 | ||||||
|  |         } | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); | ||||||
|  |         w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); |     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); | ||||||
| 
 | 
 | ||||||
|     mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md); |     mkldnnUtils::setBlockStrides(gradO, gradO_user_md); | ||||||
| 
 | 
 | ||||||
|     // gradI
 |     // gradI
 | ||||||
|     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); |     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); | ||||||
| 
 | 
 | ||||||
|     mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md); |     mkldnnUtils::setBlockStrides(gradI, gradI_user_md); | ||||||
| 
 | 
 | ||||||
|     // gradW
 |     // gradW
 | ||||||
|     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); |     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); | ||||||
|     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(4);   // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
 |         gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); |         uint i0, i1, i2, i3, i4; | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); |         if(0 == wFormat) { | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); |             i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2;     // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
 | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2); |         } | ||||||
|  |         else if(1 == wFormat) { | ||||||
|  |             i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3;     // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
 | ||||||
|  |         } | ||||||
|  |         gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); | ||||||
|  |         gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); | ||||||
|  |         gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); | ||||||
|  |         gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); | ||||||
|  |         gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     // gradB
 |     // gradB
 | ||||||
|     dnnl::memory::desc gradB_mkl_md; |     dnnl::memory::desc gradB_mkl_md; | ||||||
| @ -231,10 +267,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md,  op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md,  op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md,  op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md,  op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); |     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); | ||||||
| @ -486,7 +522,7 @@ static void conv3dBpMKLDNN(sd::graph::Context &block, | |||||||
| PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { | PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto input = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                  // [kD, kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                  // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;       // [oC]
 |     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;       // [oC]
 | ||||||
|     auto output = OUTPUT_VARIABLE(0);                                  // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 |     auto output = OUTPUT_VARIABLE(0);                                  // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 | ||||||
| 
 | 
 | ||||||
| @ -507,12 +543,13 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int paddingMode = INT_ARG(12);                                               // 1-SAME,  0-VALID
 |     int paddingMode = INT_ARG(12);                                               // 1-SAME,  0-VALID
 | ||||||
|     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;        // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;        // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| @ -520,7 +557,7 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { | |||||||
|     if (paddingMode)                       // SAME
 |     if (paddingMode)                       // SAME
 | ||||||
|         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); |         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); | ||||||
| 
 | 
 | ||||||
|     conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); |     conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -538,12 +575,12 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { | |||||||
| PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { | PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input = INPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto input = INPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                               // [kD, kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                               // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                    // [oC]
 |     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                    // [oC]
 | ||||||
|     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);         // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 |     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);         // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_NULLIFIED(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
 |     auto gradI = OUTPUT_NULLIFIED(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
 | ||||||
|     auto gradW = OUTPUT_NULLIFIED(1);                                                // [kD, kH, kW, iC, oC] always
 |     auto gradW = OUTPUT_NULLIFIED(1);                                                // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                  // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                  // [oC]
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); | ||||||
| @ -564,10 +601,11 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { | |||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID
 |     int paddingMode = INT_ARG(12);                                              // 1-SAME,  0-VALID
 | ||||||
|     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;        // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;        // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     if(paddingMode)                       // SAME
 |     if(paddingMode)                       // SAME
 | ||||||
|         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); |         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); | ||||||
| @ -576,26 +614,26 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { | |||||||
|     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); |     ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); |     conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { | PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { | ||||||
|     auto input = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto input = INPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                               // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                    // [oC]
 | ||||||
|     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 |     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);         // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
 |     auto gradI = OUTPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
 | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, iC, oC] always
 |     auto gradW = OUTPUT_VARIABLE(1);                                                // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
| 
 | 
 | ||||||
|     return block.isUseMKLDNN() && |     return block.isUseMKLDNN() && | ||||||
|            sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB}); |            sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB}); | ||||||
|  | |||||||
| @ -34,19 +34,30 @@ namespace platforms { | |||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, | static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, | ||||||
|                             const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, |                             const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, | ||||||
|                             const int paddingMode, const bool isNCHW) { |                             const int paddingMode, const bool isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // weights [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation
 |     // mkl supports weights format [oC, iC, kH, kW] only
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims strides   = { sH, sW }; |     dnnl::memory::dims strides   = { sH, sW }; | ||||||
|     dnnl::memory::dims padding   = { pH, pW }; |     dnnl::memory::dims padding   = { pH, pW }; | ||||||
|     dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; |     dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; | ||||||
|     dnnl::memory::dims dilation  = { dH-1, dW-1 }; |     dnnl::memory::dims dilation  = { dH-1, dW-1 }; | ||||||
| 
 | 
 | ||||||
|  |     uint i0, i1, i2, i3; | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         i0 = 2; i1 = 3; i2 = 0; i3 = 1;     // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else if(1 == wFormat) { | ||||||
|  |         i0 = 1; i1 = 0; i2 = 2; i3 = 3;     // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         i0 = 3; i1 = 0; i2 = 1; i3 = 2;     // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // input type
 |     // input type
 | ||||||
|     dnnl::memory::data_type xType; |     dnnl::memory::data_type xType; | ||||||
|     if(input->dataType() == DataType::FLOAT32) |     if(input->dataType() == DataType::FLOAT32) | ||||||
| @ -76,8 +87,8 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N | |||||||
|     else |     else | ||||||
|         zType = dnnl::memory::data_type::s32; |         zType = dnnl::memory::data_type::s32; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; |     dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {oC, iC, kH, kW}; |     dnnl::memory::dims wDims = {oC, iC, kH, kW}; | ||||||
| @ -87,17 +98,17 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 4, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2);  // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
 |     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); |     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); |     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); |     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); | ||||||
| 
 | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     dnnl::memory::desc b_mkl_md; |     dnnl::memory::desc b_mkl_md; | ||||||
| @ -106,8 +117,8 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N | |||||||
| 
 | 
 | ||||||
|     // output
 |     // output
 | ||||||
|     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); |     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); |     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(output, 4, z_user_md); |     mkldnnUtils::setBlockStrides(output, z_user_md); | ||||||
| 
 | 
 | ||||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -124,10 +135,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md,  op_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md,  op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md,  op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md,  op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     if(bias != nullptr) { |     if(bias != nullptr) { | ||||||
| @ -156,19 +167,30 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N | |||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, | static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, | ||||||
|                                     const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, |                                     const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, | ||||||
|                                     const int paddingMode, const bool isNCHW) { |                                     const int paddingMode, const bool isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // weights and gradW [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation
 |     // mkl supports weights/gradW in [oC, iC, kH, kW] format only
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims strides   = { sH, sW }; |     dnnl::memory::dims strides   = { sH, sW }; | ||||||
|     dnnl::memory::dims padding   = { pH, pW }; |     dnnl::memory::dims padding   = { pH, pW }; | ||||||
|     dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; |     dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; | ||||||
|     dnnl::memory::dims dilation  = { dH-1, dW-1 }; |     dnnl::memory::dims dilation  = { dH-1, dW-1 }; | ||||||
| 
 | 
 | ||||||
|  |     uint i0, i1, i2, i3; | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         i0 = 2; i1 = 3; i2 = 0; i3 = 1;     // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else if(1 == wFormat) { | ||||||
|  |         i0 = 1; i1 = 0; i2 = 2; i3 = 3;     // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         i0 = 3; i1 = 0; i2 = 1; i3 = 2;     // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // input type
 |     // input type
 | ||||||
|     dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; |     dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; | ||||||
|     // weights type
 |     // weights type
 | ||||||
| @ -182,8 +204,8 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const | |||||||
|     // gradB type
 |     // gradB type
 | ||||||
|     dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; |     dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; |     dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {oC, iC, kH, kW}; |     dnnl::memory::dims wDims = {oC, iC, kH, kW}; | ||||||
| @ -193,36 +215,36 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 4, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2);  // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
 |     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); |     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); |     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); |     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); |     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); |     mkldnnUtils::setBlockStrides(gradO, gradO_user_md); | ||||||
| 
 | 
 | ||||||
|     // gradI
 |     // gradI
 | ||||||
|     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); |     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); |     mkldnnUtils::setBlockStrides(gradI, gradI_user_md); | ||||||
| 
 | 
 | ||||||
|     // gradW
 |     // gradW
 | ||||||
|     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); |     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); | ||||||
|     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2);  // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
 |     gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); |     gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); |     gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); |     gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); | ||||||
| 
 | 
 | ||||||
|     // gradB
 |     // gradB
 | ||||||
|     dnnl::memory::desc gradB_mkl_md; |     dnnl::memory::desc gradB_mkl_md; | ||||||
| @ -251,10 +273,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md,  op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); |     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); | ||||||
| @ -311,7 +333,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const | |||||||
| PLATFORM_IMPL(deconv2d, ENGINE_CPU) { | PLATFORM_IMPL(deconv2d, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 | ||||||
| @ -327,14 +349,15 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) { | |||||||
|     int pW = INT_ARG(5);                                                        // paddings width
 |     int pW = INT_ARG(5);                                                        // paddings width
 | ||||||
|     int dH = INT_ARG(6);                                                        // dilations height
 |     int dH = INT_ARG(6);                                                        // dilations height
 | ||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| @ -344,7 +367,7 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) { | |||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); |         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); |     deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -377,12 +400,12 @@ PLATFORM_CHECK(deconv2d, ENGINE_CPU) { | |||||||
| PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { | PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), gradI
 |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), gradI
 | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, oC, iC] always
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); | ||||||
| @ -398,18 +421,19 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { | |||||||
|     int pW = INT_ARG(5);                                                        // paddings width
 |     int pW = INT_ARG(5);                                                        // paddings width
 | ||||||
|     int dH = INT_ARG(6);                                                        // dilations height
 |     int dH = INT_ARG(6);                                                        // dilations height
 | ||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // true output height, width
 |     int trueoH, trueoW;          // true output height, width
 | ||||||
|     ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); |     ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape  = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape  = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
| @ -420,19 +444,19 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { | |||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); |         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); |     deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| PLATFORM_CHECK(deconv2d_bp, ENGINE_CPU) { | PLATFORM_CHECK(deconv2d_bp, ENGINE_CPU) { | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), gradI
 |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), gradI
 | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, oC, iC] always
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     int dH = INT_ARG(6);                                                        // dilations height
 |     int dH = INT_ARG(6);                                                        // dilations height
 | ||||||
|  | |||||||
| @ -34,7 +34,7 @@ namespace platforms { | |||||||
| static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI, | static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI, | ||||||
|                                     const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW, |                                     const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW, | ||||||
|                                     const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, |                                     const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, | ||||||
|                                     const bool isNCHW) { |                                     const bool isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
 |     // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
 | ||||||
|     // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
 |     // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
 | ||||||
| @ -52,8 +52,8 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad | |||||||
|     // gradI type
 |     // gradI type
 | ||||||
|     dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; |     dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; |     dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {oC, iC, kH, kW}; |     dnnl::memory::dims wDims = {oC, iC, kH, kW}; | ||||||
| @ -66,7 +66,7 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad | |||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);   // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
 |     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);   // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); |     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); | ||||||
| @ -75,13 +75,13 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad | |||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); |     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); |     mkldnnUtils::setBlockStrides(gradO, gradO_user_md); | ||||||
| 
 | 
 | ||||||
|     // gradI
 |     // gradI
 | ||||||
|     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); |     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); |     mkldnnUtils::setBlockStrides(gradI, gradI_user_md); | ||||||
| 
 | 
 | ||||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -101,10 +101,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md,  op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md,  op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); |     mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); | ||||||
| 
 | 
 | ||||||
|     // gradI
 |     // gradI
 | ||||||
|     auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); |     auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); | ||||||
| @ -128,10 +128,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad | |||||||
| PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { | PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto gradO      = INPUT_VARIABLE(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO      = INPUT_VARIABLE(2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
|     auto weights    = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC] always
 |     auto weights    = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto gradIShape = INPUT_VARIABLE(0);                                                // [4] - shape of input of conv2d (that is shape of gradI)
 |     auto gradIShape = INPUT_VARIABLE(0);                                                // [4] - shape of input of conv2d (that is shape of gradI)
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     auto gradI = OUTPUT_VARIABLE(0);                                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
 |     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
 | ||||||
|     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
 |     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
 | ||||||
| @ -143,6 +143,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     const int rank = gradO->rankOf(); |     const int rank = gradO->rankOf(); | ||||||
| 
 | 
 | ||||||
| @ -188,7 +189,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { | |||||||
|     //     gradO = new NDArray(gradO->permute({0,3,1,2}));    // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
 |     //     gradO = new NDArray(gradO->permute({0,3,1,2}));    // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
 | ||||||
|     // }
 |     // }
 | ||||||
| 
 | 
 | ||||||
|     deconv2TFdBackPropMKLDNN(weights, gradO,  gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); |     deconv2TFdBackPropMKLDNN(weights, gradO,  gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     // delete weights;
 |     // delete weights;
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -35,19 +35,30 @@ namespace platforms { | |||||||
| static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, | static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, | ||||||
|                             const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, |                             const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, | ||||||
|                             const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, |                             const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, | ||||||
|                             const bool isNCDHW) { |                             const bool isNCDHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // weights [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation
 |     // mkl supports weights in [oC, iC, kD, kH, kW] only
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims strides   = { sD, sH, sW }; |     dnnl::memory::dims strides   = { sD, sH, sW }; | ||||||
|     dnnl::memory::dims padding   = { pD, pH, pW }; |     dnnl::memory::dims padding   = { pD, pH, pW }; | ||||||
|     dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; |     dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; | ||||||
|     dnnl::memory::dims dilation  = { dD-1, dH-1, dW-1 }; |     dnnl::memory::dims dilation  = { dD-1, dH-1, dW-1 }; | ||||||
| 
 | 
 | ||||||
|  |     uint i0, i1, i2, i3, i4; | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2;     // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else if(1 == wFormat) { | ||||||
|  |         i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4;     // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3;     // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // input type
 |     // input type
 | ||||||
|     dnnl::memory::data_type xType; |     dnnl::memory::data_type xType; | ||||||
|     if(input->dataType() == DataType::FLOAT32) |     if(input->dataType() == DataType::FLOAT32) | ||||||
| @ -77,8 +88,8 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N | |||||||
|     else |     else | ||||||
|         zType = dnnl::memory::data_type::s32; |         zType = dnnl::memory::data_type::s32; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; |     dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; |     dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; | ||||||
| @ -88,18 +99,18 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 5, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);   // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
 |     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4); |     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); |     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); |     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); | ||||||
|     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); |     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); | ||||||
| 
 | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     dnnl::memory::desc b_mkl_md; |     dnnl::memory::desc b_mkl_md; | ||||||
| @ -108,8 +119,8 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N | |||||||
| 
 | 
 | ||||||
|     // output
 |     // output
 | ||||||
|     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); |     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); |     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(output, 5, z_user_md); |     mkldnnUtils::setBlockStrides(output, z_user_md); | ||||||
| 
 | 
 | ||||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -126,10 +137,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md,  op_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md,  op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md,  op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md,  op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     if(bias != nullptr) { |     if(bias != nullptr) { | ||||||
| @ -161,19 +172,30 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
|                                     const int sD, const int sH, const int sW, |                                     const int sD, const int sH, const int sW, | ||||||
|                                     const int pD, const int pH, const int pW, |                                     const int pD, const int pH, const int pW, | ||||||
|                                     const int dD, const int dH, const int dW, |                                     const int dD, const int dH, const int dW, | ||||||
|                                     const bool isNCDHW) { |                                     const bool isNCDHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation
 |     // mkl supports weights/gradW in [oC, iC, kD, kH, kW] format only
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims strides   = { sD, sH, sW }; |     dnnl::memory::dims strides   = { sD, sH, sW }; | ||||||
|     dnnl::memory::dims padding   = { pD, pH, pW }; |     dnnl::memory::dims padding   = { pD, pH, pW }; | ||||||
|     dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; |     dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; | ||||||
|     dnnl::memory::dims dilation  = { dD-1, dH-1, dW-1 }; |     dnnl::memory::dims dilation  = { dD-1, dH-1, dW-1 }; | ||||||
| 
 | 
 | ||||||
|  |     uint i0, i1, i2, i3, i4; | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2;     // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else if(1 == wFormat) { | ||||||
|  |         i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4;     // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3;     // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // input type
 |     // input type
 | ||||||
|     dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; |     dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; | ||||||
|     // weights type
 |     // weights type
 | ||||||
| @ -187,8 +209,8 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
|     // gradB type
 |     // gradB type
 | ||||||
|     dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; |     dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; |     dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; |     dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; | ||||||
| @ -198,38 +220,38 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 5, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);   // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
 |     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4); |     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); |     w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); |     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); | ||||||
|     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); |     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); |     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md); |     mkldnnUtils::setBlockStrides(gradO, gradO_user_md); | ||||||
| 
 | 
 | ||||||
|     // gradI
 |     // gradI
 | ||||||
|     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); |     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md); |     mkldnnUtils::setBlockStrides(gradI, gradI_user_md); | ||||||
| 
 | 
 | ||||||
|     // gradW
 |     // gradW
 | ||||||
|     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, gradWType, wFormat); |     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); |     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); | ||||||
|     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3);   // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
 |     gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(4); |     gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); |     gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); |     gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2); |     gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); | ||||||
| 
 | 
 | ||||||
|     // gradB
 |     // gradB
 | ||||||
|     dnnl::memory::desc gradB_mkl_md; |     dnnl::memory::desc gradB_mkl_md; | ||||||
| @ -259,10 +281,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md,  op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md,  op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); |     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); | ||||||
| @ -319,7 +341,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
| PLATFORM_IMPL(deconv3d, ENGINE_CPU) { | PLATFORM_IMPL(deconv3d, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 | ||||||
| @ -341,12 +363,13 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) { | |||||||
|     int dW = INT_ARG(11);                                                           // dilations width
 |     int dW = INT_ARG(11);                                                           // dilations width
 | ||||||
|     int isSameMode = INT_ARG(12);                                                   // 0-SAME,  1-VALID
 |     int isSameMode = INT_ARG(12);                                                   // 0-SAME,  1-VALID
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;           // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;           // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;             // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong>  expectedWeightsShape = {kD, kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| @ -356,7 +379,7 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) { | |||||||
|         ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); |         ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); |     deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -390,12 +413,12 @@ PLATFORM_CHECK(deconv3d, ENGINE_CPU) { | |||||||
| PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { | PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
 |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
 | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, oC, iC] always
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); | ||||||
| @ -416,17 +439,18 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { | |||||||
|     int dH = INT_ARG(10);                                                       // dilations height
 |     int dH = INT_ARG(10);                                                       // dilations height
 | ||||||
|     int dW = INT_ARG(11);                                                       // dilations width
 |     int dW = INT_ARG(11);                                                       // dilations width
 | ||||||
|     int isSameMode = INT_ARG(12);                                               // 0-SAME,  1-VALID
 |     int isSameMode = INT_ARG(12);                                               // 0-SAME,  1-VALID
 | ||||||
|     int isNCDHW  = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;       // INT_ARG(13): 1-NDHWC, 0-NCDHW
 |     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;        // INT_ARG(13): 1-NDHWC, 0-NCDHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;         // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); | ||||||
| 
 | 
 | ||||||
|     int trueoD, trueoH, trueoW;          // true output height, width
 |     int trueoD, trueoH, trueoW;          // true output height, width
 | ||||||
|     ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); |     ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
| @ -435,7 +459,7 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { | |||||||
|     if(isSameMode)               // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
 |     if(isSameMode)               // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
 | ||||||
|         ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); |         ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); | ||||||
| 
 | 
 | ||||||
|     deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); |     deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -443,12 +467,12 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { | |||||||
| 
 | 
 | ||||||
| PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) { | PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) { | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, oC, iC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
 |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
 | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, oC, iC] always
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     int dD = INT_ARG(9);                                                        // dilations depth
 |     int dD = INT_ARG(9);                                                        // dilations depth
 | ||||||
|  | |||||||
| @ -35,19 +35,19 @@ namespace platforms { | |||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, | static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, | ||||||
|                                   const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, |                                   const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, | ||||||
|                                   const int paddingMode, const bool isNCHW) { |                                   const int paddingMode, const bool isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // mkl supports only following case: mC = 1, oC = iC
 |     // mkl supports only following case: mC = 1, oC = iC
 | ||||||
| 
 | 
 | ||||||
|     // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
 |     // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
 | ||||||
|     // weights [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute
 |     // weights {iC, mC, 1, kH, kW}
 | ||||||
|     // bias [oC], may be nullptr
 |     // bias [oC], may be nullptr
 | ||||||
|     // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
 |     // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
 | ||||||
|     // oC = iC*mC
 |     // oC = iC*mC
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, mC, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;           // corresponding indexes
 |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;           // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC);                                   // channels multiplier
 |     mC = weights->sizeAt(indWmC);                                   // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|     const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 |     const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 | ||||||
| @ -57,6 +57,17 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
|     dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; |     dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; | ||||||
|     dnnl::memory::dims dilation  = { dH-1, dW-1}; |     dnnl::memory::dims dilation  = { dH-1, dW-1}; | ||||||
| 
 | 
 | ||||||
|  |     uint i0, i1, i2, i3; | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         i0 = 2; i1 = 3; i2 = 0; i3 = 1;     // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else if(1 == wFormat) { | ||||||
|  |         i0 = 1; i1 = 0; i2 = 2; i3 = 3;     // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         i0 = 3; i1 = 0; i2 = 1; i3 = 2;     // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW]
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // input type
 |     // input type
 | ||||||
|     dnnl::memory::data_type xType; |     dnnl::memory::data_type xType; | ||||||
|     if(input->dataType() == DataType::FLOAT32) |     if(input->dataType() == DataType::FLOAT32) | ||||||
| @ -86,8 +97,8 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
|     else |     else | ||||||
|         zType = dnnl::memory::data_type::s32; |         zType = dnnl::memory::data_type::s32; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; |     dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; |     dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; | ||||||
| @ -97,18 +108,18 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 4, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights, make permute [kH, kW, iC, mC] ->  [iC, mC, 1, kH, kW];
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2);   // permute
 |     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);   // permute
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); |     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = 0; |     w_user_md.data.format_desc.blocking.strides[2] = 0; | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0); |     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); | ||||||
|     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1); |     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); | ||||||
| 
 | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     dnnl::memory::desc b_mkl_md; |     dnnl::memory::desc b_mkl_md; | ||||||
| @ -117,8 +128,8 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
| 
 | 
 | ||||||
|     // output
 |     // output
 | ||||||
|     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); |     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat); |     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(output, 4, z_user_md); |     mkldnnUtils::setBlockStrides(output, z_user_md); | ||||||
| 
 | 
 | ||||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -135,10 +146,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // bias
 |     // bias
 | ||||||
|     if(bias != nullptr) { |     if(bias != nullptr) { | ||||||
| @ -166,19 +177,19 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, | |||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, | static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, | ||||||
|                                     const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, |                                     const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, | ||||||
|                                     const int paddingMode, const bool isNCHW) { |                                     const int paddingMode, const bool isNCHW, const int wFormat) { | ||||||
| 
 | 
 | ||||||
|     // mkl supports only following case: mC = 1, oC = iC
 |     // mkl supports only following case: mC = 1, oC = iC
 | ||||||
| 
 | 
 | ||||||
|     // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
 |     // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given
 | ||||||
|     // weights, gradW [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute
 |     // weights/gradW {iC, mC, 1, kH, kW}
 | ||||||
|     // gradB [oC], may be nullptr
 |     // gradB [oC], may be nullptr
 | ||||||
|     // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
 |     // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc
 | ||||||
|     // oC = iC*mC
 |     // oC = iC*mC
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, mC, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;           // corresponding indexes
 |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;           // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC); |     mC = weights->sizeAt(indWmC); | ||||||
| 
 | 
 | ||||||
|     const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 |     const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;       // dH == 1 for causal mode in conv1d
 | ||||||
| @ -188,6 +199,17 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w | |||||||
|     dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; |     dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; | ||||||
|     dnnl::memory::dims dilation  = { dH-1, dW-1}; |     dnnl::memory::dims dilation  = { dH-1, dW-1}; | ||||||
| 
 | 
 | ||||||
|  |     uint i0, i1, i2, i3; | ||||||
|  |     if(0 == wFormat) { | ||||||
|  |         i0 = 2; i1 = 3; i2 = 0; i3 = 1;     // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else if(1 == wFormat) { | ||||||
|  |         i0 = 1; i1 = 0; i2 = 2; i3 = 3;     // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW]
 | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         i0 = 3; i1 = 0; i2 = 1; i3 = 2;     // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW]
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // input type
 |     // input type
 | ||||||
|     dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; |     dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; | ||||||
|     // weights type
 |     // weights type
 | ||||||
| @ -201,8 +223,8 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w | |||||||
|     // gradB type
 |     // gradB type
 | ||||||
|     dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; |     dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; |     dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; | ||||||
|     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; |     dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; | ||||||
| 
 | 
 | ||||||
|     dnnl::memory::dims xDims = {bS, iC, iH, iW}; |     dnnl::memory::dims xDims = {bS, iC, iH, iW}; | ||||||
|     dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; |     dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; | ||||||
| @ -212,38 +234,38 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w | |||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); |     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); |     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(input, 4, x_user_md); |     mkldnnUtils::setBlockStrides(input, x_user_md); | ||||||
| 
 | 
 | ||||||
|     // weights, make permute [kH, kW, iC, mC] ->  [iC, mC, 1, kH, kW];
 |     // weights
 | ||||||
|     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); |     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); |     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); | ||||||
|     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2);   // permute
 |     w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);   // permute
 | ||||||
|     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); |     w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); | ||||||
|     w_user_md.data.format_desc.blocking.strides[2] = 0; |     w_user_md.data.format_desc.blocking.strides[2] = 0; | ||||||
|     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0); |     w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); | ||||||
|     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1); |     w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat); |     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); |     mkldnnUtils::setBlockStrides(gradO, gradO_user_md); | ||||||
| 
 | 
 | ||||||
|     // gradI
 |     // gradI
 | ||||||
|     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat); |     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl); | ||||||
|     mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); |     mkldnnUtils::setBlockStrides(gradI, gradI_user_md); | ||||||
| 
 | 
 | ||||||
|     // gradW, make permute [kH, kW, iC, mC] ->  [iC, mC, 1, kH, kW];
 |     // gradW
 | ||||||
|     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); |     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); | ||||||
|     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); |     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); | ||||||
|     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 |     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2);   // permute
 |     gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);   // permute
 | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); |     gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[2] = 0; |     gradW_user_md.data.format_desc.blocking.strides[2] = 0; | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(0); |     gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i2); | ||||||
|     gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(1); |     gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i3); | ||||||
| 
 | 
 | ||||||
|     // gradB
 |     // gradB
 | ||||||
|     dnnl::memory::desc gradB_mkl_md; |     dnnl::memory::desc gradB_mkl_md; | ||||||
| @ -272,10 +294,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // weights
 |     // weights
 | ||||||
|     mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); |     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); | ||||||
| @ -332,7 +354,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w | |||||||
| PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { | PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] = iC*mC
 |     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC] = iC*mC
 | ||||||
| 
 | 
 | ||||||
|     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
 |     auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
 | ||||||
| @ -347,21 +369,22 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME, 2-CAUSAL
 |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME, 2-CAUSAL
 | ||||||
|     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     int isNCHW     = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC);                           // channels multiplier
 |     mC = weights->sizeAt(indWmC);                           // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); |     REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); | ||||||
|     if (bias) |     if (bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); |     depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -394,12 +417,12 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) { | |||||||
| PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { | PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, mC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     auto gradI = OUTPUT_NULLIFIED(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
|     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kH, kW, iC, mC] always
 |     auto gradW = OUTPUT_NULLIFIED(1);                                                 // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); |     REQUIRE_TRUE(input->rankOf()   == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); | ||||||
| @ -416,10 +439,11 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { | |||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME, 2-CAUSAL
 |     int paddingMode = INT_ARG(8);                                               // 0-VALID, 1-SAME, 2-CAUSAL
 | ||||||
|     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 |     int isNCHW  = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 1-NHWC, 0-NCHW
 | ||||||
|  |     int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0;         // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC]
 | ||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 |     int bS, iC, iH, iW, mC, oC, oH, oW;                     // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
 | ||||||
|     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 |     int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH;   // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); | ||||||
|     mC = weights->sizeAt(indWmC);                           // channels multiplier
 |     mC = weights->sizeAt(indWmC);                           // channels multiplier
 | ||||||
| 
 | 
 | ||||||
|     int trueoH, trueoW;          // correct output height, width
 |     int trueoH, trueoW;          // correct output height, width
 | ||||||
| @ -428,13 +452,13 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { | |||||||
|     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); |     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1}); | ||||||
|     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, mC}; |     std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0,  "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); |     REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); | ||||||
|     if(bias) |     if(bias) | ||||||
|         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); |         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); | ||||||
| 
 | 
 | ||||||
|     depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); |     depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
| @ -443,12 +467,12 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { | |||||||
| PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) { | PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) { | ||||||
| 
 | 
 | ||||||
|     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
 |     auto input   = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
 | ||||||
|     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, mC] always
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC]
 |     auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC] = [iC*mC]
 | ||||||
|     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
 |     auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
 | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
 |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
 | ||||||
|     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, iC, mC] always
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC]
 | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     const DataType xType = input->dataType(); |     const DataType xType = input->dataType(); | ||||||
|  | |||||||
| @ -272,14 +272,14 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* | |||||||
| 
 | 
 | ||||||
|     // provide memory and check whether reorder is required
 |     // provide memory and check whether reorder is required
 | ||||||
|     // x
 |     // x
 | ||||||
|     mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, lstm_prim_desc.src_layer_desc(), DNNL_ARG_SRC_LAYER); |     mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]); | ||||||
|   | 
 | ||||||
|     // wx
 |     // wx
 | ||||||
|     mkldnnUtils::loadDataToMklStream(Wx, engine, stream, args, wx_user_md, lstm_prim_desc.weights_layer_desc(), DNNL_ARG_WEIGHTS_LAYER); |     mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]); | ||||||
| 
 | 
 | ||||||
|     // wr
 |     // wr
 | ||||||
|     mkldnnUtils::loadDataToMklStream(Wr, engine, stream, args, wr_user_md, lstm_prim_desc.weights_iter_desc(), DNNL_ARG_WEIGHTS_ITER); |     mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]); | ||||||
|      | 
 | ||||||
|     // h
 |     // h
 | ||||||
|     auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer()); |     auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer()); | ||||||
|     const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); |     const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); | ||||||
| @ -288,17 +288,17 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* | |||||||
| 
 | 
 | ||||||
|     // b
 |     // b
 | ||||||
|     if(b) { |     if(b) { | ||||||
|         mkldnnUtils::loadDataToMklStream(b, engine, stream, args, b_user_md, lstm_prim_desc.bias_desc(), DNNL_ARG_BIAS); |         mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // hI
 |     // hI
 | ||||||
|     if(hI) { |     if(hI) { | ||||||
|         mkldnnUtils::loadDataToMklStream(hI, engine, stream, args, hI_user_md, lstm_prim_desc.src_iter_desc(), DNNL_ARG_SRC_ITER); |         mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // cI
 |     // cI
 | ||||||
|     if(cI) { |     if(cI) { | ||||||
|         mkldnnUtils::loadDataToMklStream(cI, engine, stream, args, cI_user_md, lstm_prim_desc.src_iter_c_desc(), DNNL_ARG_SRC_ITER_C); |         mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     bool hLReorder(false), cLReorder(false); |     bool hLReorder(false), cLReorder(false); | ||||||
|  | |||||||
| @ -163,7 +163,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(xTR, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
|     /*
 |     /*
 | ||||||
|     auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer()); |     auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer()); | ||||||
|     const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); |     const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); | ||||||
| @ -173,7 +173,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b | |||||||
|     args[DNNL_ARG_SRC] = x_mkl_mem; |     args[DNNL_ARG_SRC] = x_mkl_mem; | ||||||
| */ | */ | ||||||
|     // y
 |     // y
 | ||||||
|     mkldnnUtils::loadDataToMklStream(yTR, engine, stream, args, y_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); |     mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); | ||||||
|     /*
 |     /*
 | ||||||
|     auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer()); |     auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer()); | ||||||
|     const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); |     const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); | ||||||
|  | |||||||
| @ -60,7 +60,7 @@ PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     if (paddingMode) |     if (paddingMode) | ||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); |         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); | ||||||
| @ -102,7 +102,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 |     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
 | ||||||
|     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 |     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|  | |||||||
| @ -60,7 +60,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     if(paddingMode)                       // SAME
 |     if(paddingMode)                       // SAME
 | ||||||
|         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); |         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); | ||||||
| @ -107,7 +107,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { | |||||||
| 
 | 
 | ||||||
|     int bS, iC, iD, iH, iW, oC, oD, oH, oW;               // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 |     int bS, iC, iD, iH, iW, oC, oD, oH, oW;               // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
 | ||||||
|     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 |     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
 | ||||||
|     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); |     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); |     std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); | ||||||
|     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); |     REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); | ||||||
|  | |||||||
| @ -30,7 +30,7 @@ namespace mkldnnUtils { | |||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////
 | ||||||
| void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){ | void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){ | ||||||
|      | 
 | ||||||
|     std::vector<int64_t> vDims(rank); |     std::vector<int64_t> vDims(rank); | ||||||
|     for (auto i = 0; i < rank; i++) { |     for (auto i = 0; i < rank; i++) { | ||||||
|         vDims[i] = array->sizeAt(i); |         vDims[i] = array->sizeAt(i); | ||||||
| @ -56,25 +56,27 @@ dnnl::memory::format_tag   getFormat(const int rank){ | |||||||
|         } |         } | ||||||
|         return dnnl::memory::format_tag::a; // 1 == dataSetRank
 |         return dnnl::memory::format_tag::a; // 1 == dataSetRank
 | ||||||
| } | } | ||||||
|  | 
 | ||||||
| //////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////
 | ||||||
| void   setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd){ | void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd){ | ||||||
|         if (array->ews() != 1 || array->ordering() != 'c') { | 
 | ||||||
|             mklMd.data.format_kind = dnnl_blocked;    // overrides format
 |     if (array->ews() != 1 || array->ordering() != 'c') { | ||||||
|             for (auto i = 0; i < rank; ++i) { |         mklMd.data.format_kind = dnnl_blocked;    // overrides format
 | ||||||
|                 mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i); |         for (auto i = 0; i < array->rankOf(); ++i) { | ||||||
|             } |             mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i); | ||||||
|         } |         } | ||||||
|  |     } | ||||||
| } | } | ||||||
| ////////////////////////////////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////////////////////////////////
 | ||||||
| void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream,  | void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, | ||||||
|                          std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG ){ |                          dnnl::memory& arg) { | ||||||
|                  | 
 | ||||||
|                 auto user_mem = dnnl::memory(user_md, engine, array->getBuffer()); |     auto user_mem = dnnl::memory(user_md, engine, array->getBuffer()); | ||||||
|                 const bool bReorder = primitive_md != user_mem.get_desc(); |     const bool bReorder = primitive_md != user_mem.get_desc(); | ||||||
|                 auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; |     auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; | ||||||
|                 if (bReorder) |     if (bReorder) | ||||||
|                     dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); |         dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); | ||||||
|                 args[DNNL_ARG] = mkl_mem; |     arg = mkl_mem; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////
 | ||||||
| @ -95,7 +97,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output, | |||||||
| 
 | 
 | ||||||
|     if(rank == 4) {     // 2d
 |     if(rank == 4) {     // 2d
 | ||||||
| 
 | 
 | ||||||
|         ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |         ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|         strides   = { sH, sW }; |         strides   = { sH, sW }; | ||||||
|         kernel    = { kH, kW }; |         kernel    = { kH, kW }; | ||||||
| @ -108,7 +110,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output, | |||||||
|     } |     } | ||||||
|     else {              // 3d
 |     else {              // 3d
 | ||||||
| 
 | 
 | ||||||
|         ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); |         ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); | ||||||
| 
 | 
 | ||||||
|         strides   = { sD, sH, sW }; |         strides   = { sD, sH, sW }; | ||||||
|         kernel    = { kD, kH, kW }; |         kernel    = { kD, kH, kW }; | ||||||
| @ -162,7 +164,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output, | |||||||
|     // provide memory buffers and check whether reorder is required
 |     // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|     // input
 |     // input
 | ||||||
|     mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); |     mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|     // output
 |     // output
 | ||||||
|     auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); |     auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); | ||||||
| @ -199,7 +201,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, | |||||||
| 
 | 
 | ||||||
|     if(rank == 4) {     // 2d
 |     if(rank == 4) {     // 2d
 | ||||||
| 
 | 
 | ||||||
|         ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); |         ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|         strides   = { sH, sW }; |         strides   = { sH, sW }; | ||||||
|         kernel    = { kH, kW }; |         kernel    = { kH, kW }; | ||||||
| @ -212,7 +214,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, | |||||||
|     } |     } | ||||||
|     else {              // 3d
 |     else {              // 3d
 | ||||||
| 
 | 
 | ||||||
|         ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); |         ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); | ||||||
| 
 | 
 | ||||||
|         strides   = { sD, sH, sW }; |         strides   = { sD, sH, sW }; | ||||||
|         kernel    = { kD, kH, kW }; |         kernel    = { kD, kH, kW }; | ||||||
| @ -280,8 +282,8 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, | |||||||
|     std::unordered_map<int, dnnl::memory> args; |     std::unordered_map<int, dnnl::memory> args; | ||||||
| 
 | 
 | ||||||
|     // gradO
 |     // gradO
 | ||||||
|     mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); |     mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); | ||||||
|      | 
 | ||||||
|     // gradI
 |     // gradI
 | ||||||
|     auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); |     auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); | ||||||
|     const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); |     const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); | ||||||
| @ -291,8 +293,8 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, | |||||||
|     if(mode == algorithm::pooling_max) { |     if(mode == algorithm::pooling_max) { | ||||||
| 
 | 
 | ||||||
|         // input
 |         // input
 | ||||||
|         mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); |         mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
|          | 
 | ||||||
|         // z
 |         // z
 | ||||||
|         auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); |         auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); | ||||||
|         args[DNNL_ARG_DST] = z_mkl_mem; |         args[DNNL_ARG_DST] = z_mkl_mem; | ||||||
|  | |||||||
| @ -131,7 +131,7 @@ namespace sd { | |||||||
|          * @param reference to memory descriptor |          * @param reference to memory descriptor | ||||||
|          * @return memory format |          * @return memory format | ||||||
|          */ |          */ | ||||||
|         void   setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd); |         void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd); | ||||||
|         //////////////////////////////////////////////////////////////////////
 |         //////////////////////////////////////////////////////////////////////
 | ||||||
|         /**
 |         /**
 | ||||||
|         * This function load and reorder user memory to mkl |         * This function load and reorder user memory to mkl | ||||||
| @ -143,8 +143,8 @@ namespace sd { | |||||||
|         * @param primitive memory descriptor |         * @param primitive memory descriptor | ||||||
|         * @param dnnl arg activation enumerator |         * @param dnnl arg activation enumerator | ||||||
|         */ |         */ | ||||||
|         void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream, |         void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, | ||||||
|              std::unordered_map<int, dnnl::memory>& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG); |                                 dnnl::memory& arg); | ||||||
| 
 | 
 | ||||||
|         /**
 |         /**
 | ||||||
|          * Utility methods for MKLDNN |          * Utility methods for MKLDNN | ||||||
|  | |||||||
| @ -55,12 +55,12 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|                 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); |                 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); | ||||||
|                 dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); |                 dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); | ||||||
|                 mkldnnUtils::setBlockStrides(x, xRank, x_user_md); |                 mkldnnUtils::setBlockStrides(x, x_user_md); | ||||||
| 
 | 
 | ||||||
|                 // z
 |                 // z
 | ||||||
|                 dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format); |                 dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format); | ||||||
|                 dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format); |                 dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format); | ||||||
|                 mkldnnUtils::setBlockStrides(z, xRank, z_user_md); |                 mkldnnUtils::setBlockStrides(z, z_user_md); | ||||||
| 
 | 
 | ||||||
|                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -80,7 +80,7 @@ namespace sd { | |||||||
|                 // provide memory buffers and check whether reorder is required
 |                 // provide memory buffers and check whether reorder is required
 | ||||||
| 
 | 
 | ||||||
|                 // input
 |                 // input
 | ||||||
|                 mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); |                 mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|                 // z
 |                 // z
 | ||||||
|                 auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); |                 auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); | ||||||
| @ -156,19 +156,19 @@ namespace sd { | |||||||
|                 // x
 |                 // x
 | ||||||
|                 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 mkldnnUtils::setBlockStrides(x, xRank, x_user_md); |                 mkldnnUtils::setBlockStrides(x, x_user_md); | ||||||
| 
 | 
 | ||||||
|                 // dLdx
 |                 // dLdx
 | ||||||
|                 dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); | ||||||
|                 dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); | ||||||
|                 mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md); |                 mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); | ||||||
|                 // todo if mkl does not support broadcast we can remove this
 |                 // todo if mkl does not support broadcast we can remove this
 | ||||||
|                 format = mkldnnUtils::getFormat(dLdzRank); |                 format = mkldnnUtils::getFormat(dLdzRank); | ||||||
| 
 | 
 | ||||||
|                 // dLdz
 |                 // dLdz
 | ||||||
|                 dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); | ||||||
|                 dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); | ||||||
|                 mkldnnUtils::setBlockStrides(dLdz, dLdzRank, dLdz_user_md); |                 mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); | ||||||
| 
 | 
 | ||||||
|                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -188,7 +188,7 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|                 // provide memory buffers and check whether reorder is required for forward
 |                 // provide memory buffers and check whether reorder is required for forward
 | ||||||
|                 // input
 |                 // input
 | ||||||
|                 mkldnnUtils::loadDataToMklStream(x, engine, stream, argsff, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); |                 mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|                 // dLdx
 |                 // dLdx
 | ||||||
|                 auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); |                 auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); | ||||||
| @ -200,7 +200,7 @@ namespace sd { | |||||||
|                 argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; |                 argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; | ||||||
|                 argsbp[DNNL_ARG_DST] = dLdx_mkl_mem; |                 argsbp[DNNL_ARG_DST] = dLdx_mkl_mem; | ||||||
|                 // dLdz
 |                 // dLdz
 | ||||||
|                 mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, argsbp, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); |                 mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]); | ||||||
| 
 | 
 | ||||||
|                 // run calculations forward
 |                 // run calculations forward
 | ||||||
|                 dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); |                 dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); | ||||||
|  | |||||||
| @ -44,12 +44,12 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|                 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 mkldnnUtils::setBlockStrides(x, xRank, x_user_md); |                 mkldnnUtils::setBlockStrides(x, x_user_md); | ||||||
| 
 | 
 | ||||||
|                 // z
 |                 // z
 | ||||||
|                 dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); | ||||||
|                 dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); | ||||||
|                 mkldnnUtils::setBlockStrides(z, xRank, z_user_md); |                 mkldnnUtils::setBlockStrides(z, z_user_md); | ||||||
| 
 | 
 | ||||||
|                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -68,7 +68,7 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|                 // provide memory buffers and check whether reorder is required
 |                 // provide memory buffers and check whether reorder is required
 | ||||||
|                 // input
 |                 // input
 | ||||||
|                 mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); |                 mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|                 // z
 |                 // z
 | ||||||
|                 auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); |                 auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); | ||||||
| @ -132,17 +132,17 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|                 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 mkldnnUtils::setBlockStrides(x, xRank, x_user_md); |                 mkldnnUtils::setBlockStrides(x, x_user_md); | ||||||
| 
 | 
 | ||||||
|                 // dLdz
 |                 // dLdz
 | ||||||
|                 dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 mkldnnUtils::setBlockStrides(dLdz, xRank, dLdz_user_md); |                 mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); | ||||||
|    | 
 | ||||||
|                 // dLdx
 |                 // dLdx
 | ||||||
|                 dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); |                 dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); | ||||||
|                 mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md); |                 mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); | ||||||
| 
 | 
 | ||||||
|                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); |                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||||
| 
 | 
 | ||||||
| @ -162,10 +162,10 @@ namespace sd { | |||||||
| 
 | 
 | ||||||
|                 // provide memory buffers and check whether reorder is required for forward
 |                 // provide memory buffers and check whether reorder is required for forward
 | ||||||
|                 // input
 |                 // input
 | ||||||
|                 mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); |                 mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); | ||||||
| 
 | 
 | ||||||
|                 // dLdz
 |                 // dLdz
 | ||||||
|                 mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, args, dLdz_user_md, op_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); |                 mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); | ||||||
| 
 | 
 | ||||||
|                 // dLdx
 |                 // dLdx
 | ||||||
|                 auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); |                 auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); | ||||||
|  | |||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user