Shyrma cudnn (#192)
* - implementation of cudnn batchnorm_bp op Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in batchnorm_bp based on cudnn api Signed-off-by: Yurii <iuriish@yahoo.com> * - move pooling mkl code and delete some unnecessary files Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation and testing cudnn pooling2d ops (avg/max, ff/bp) Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation and testing cudnn pooling 3d (ff/bp) ops Signed-off-by: Yurii <iuriish@yahoo.com> * - provide ff step in case of cudnn maxpool3d_bp op Signed-off-by: Yurii <iuriish@yahoo.com> * - remove half type from set of supported types in mkl dpethwise conv op Signed-off-by: Yurii <iuriish@yahoo.com> * - bring back cudaStreamSynchronize in batchnorm and pooling cudnn ops Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
2f08af3166
commit
7a7ee4b021
|
@ -197,8 +197,7 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
|
|||
// ***** calculations ***** //
|
||||
|
||||
// notations:
|
||||
// f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output
|
||||
// g = dLdO
|
||||
// f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO
|
||||
// stdInv = 1 / (v + eps)^0.5
|
||||
// N - batch size (product of spatial dimensions)
|
||||
|
||||
|
|
|
@ -31,31 +31,28 @@ namespace ops {
|
|||
CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", input->rankOf());
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
auto argI = *(block.getIArguments());
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const auto kH = INT_ARG(0);
|
||||
const auto kW = INT_ARG(1);
|
||||
const auto sH = INT_ARG(2);
|
||||
const auto sW = INT_ARG(3);
|
||||
int pH = INT_ARG(4);
|
||||
int pW = INT_ARG(5);
|
||||
auto pH = INT_ARG(4);
|
||||
auto pW = INT_ARG(5);
|
||||
const auto dH = INT_ARG(6);
|
||||
const auto dW = INT_ARG(7);
|
||||
const auto isSameMode = static_cast<bool>(INT_ARG(8));
|
||||
const auto extraParam0 = INT_ARG(9);
|
||||
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int oH = 0;
|
||||
int oW = 0;
|
||||
|
||||
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||
|
||||
|
@ -207,7 +204,6 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
|
|||
}
|
||||
|
||||
return Status::OK();
|
||||
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(avgpool2d_bp) {
|
||||
|
|
|
@ -51,14 +51,14 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
|
|||
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
||||
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||
|
||||
if(!isNCDHW) {
|
||||
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
|
@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
|
|||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(!isNCDHW) {
|
||||
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
|
|
|
@ -32,6 +32,7 @@ namespace ops {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
// maxpool2d corresponds to poolingMode=0
|
||||
CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf());
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
|
||||
#include "cudnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(avgpool2d, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
const auto kH = INT_ARG(0);
|
||||
const auto kW = INT_ARG(1);
|
||||
const auto sH = INT_ARG(2);
|
||||
const auto sW = INT_ARG(3);
|
||||
auto pH = INT_ARG(4);
|
||||
auto pW = INT_ARG(5);
|
||||
const auto dH = INT_ARG(6);
|
||||
const auto dW = INT_ARG(7);
|
||||
const auto paddingMode = static_cast<bool>(INT_ARG(8));
|
||||
const auto extraParam0 = INT_ARG(9);
|
||||
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int oH = 0;
|
||||
int oW = 0;
|
||||
|
||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||
|
||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
|
||||
|
||||
if (paddingMode)
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
||||
|
||||
pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(avgpool2d, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
|
||||
|
||||
return goodType && input->dataType() == output->dataType();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||
|
||||
const auto kH = INT_ARG(0); // filter(kernel) height
|
||||
const auto kW = INT_ARG(1); // filter(kernel) width
|
||||
const auto sH = INT_ARG(2); // strides height
|
||||
const auto sW = INT_ARG(3); // strides width
|
||||
auto pH = INT_ARG(4); // paddings height
|
||||
auto pW = INT_ARG(5); // paddings width
|
||||
const auto dH = INT_ARG(6); // dilations height
|
||||
const auto dW = INT_ARG(7); // dilations width
|
||||
const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||
const auto extraParam0 = INT_ARG(9);
|
||||
const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
|
||||
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(paddingMode) // SAME
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
||||
|
||||
pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(avgpool2d_bp, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||
|
||||
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
|
||||
|
||||
return goodType && (input->dataType() == gradO->dataType())
|
||||
&& (input->dataType() == gradI->dataType())
|
||||
&& shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
|
||||
#include "cudnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
||||
|
||||
int kD = INT_ARG(0); // filter(kernel) depth
|
||||
int kH = INT_ARG(1); // filter(kernel) height
|
||||
int kW = INT_ARG(2); // filter(kernel) width
|
||||
int sD = INT_ARG(3); // strides depth
|
||||
int sH = INT_ARG(4); // strides height
|
||||
int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
int dD = INT_ARG(9); // dilations depth
|
||||
int dH = INT_ARG(10); // dilations height
|
||||
int dW = INT_ARG(11); // dilations width
|
||||
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
int extraParam0 = INT_ARG(13);
|
||||
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||
|
||||
if(paddingMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
||||
|
||||
pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(avgpool3dnew, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
|
||||
|
||||
return goodType && input->dataType() == output->dataType();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
|
||||
const int kD = INT_ARG(0); // filter(kernel) depth
|
||||
const int kH = INT_ARG(1); // filter(kernel) height
|
||||
const int kW = INT_ARG(2); // filter(kernel) width
|
||||
const int sD = INT_ARG(3); // strides depth
|
||||
const int sH = INT_ARG(4); // strides height
|
||||
const int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
const int dD = INT_ARG(9); // dilations depth
|
||||
const int dH = INT_ARG(10); // dilations height
|
||||
const int dW = INT_ARG(11); // dilations width
|
||||
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging
|
||||
const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
||||
|
||||
pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
|
||||
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
|
||||
|
||||
return goodType && (input->dataType() == gradO->dataType())
|
||||
&& (input->dataType() == gradI->dataType())
|
||||
&& shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -97,9 +97,6 @@ static void batchnormCUDNN(const LaunchContext* context,
|
|||
err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/beta failed", err);
|
||||
|
||||
|
||||
if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetConvolutionNdDescriptor failed", err);
|
||||
|
||||
// provide scaling parameters
|
||||
const float alpha32(1), beta32(0);
|
||||
const double alpha64(1), beta64(0);
|
||||
|
@ -114,20 +111,127 @@ static void batchnormCUDNN(const LaunchContext* context,
|
|||
x, input->getSpecialBuffer(),
|
||||
z, output->getSpecialBuffer(),
|
||||
params,
|
||||
gamma ? gamma->getSpecialBuffer(): nullptr,
|
||||
beta ? beta->getSpecialBuffer() : nullptr,
|
||||
gamma->getSpecialBuffer(), beta->getSpecialBuffer(),
|
||||
mean->getSpecialBuffer(), variance->getSpecialBuffer(), epsilon);
|
||||
|
||||
if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err);
|
||||
|
||||
// cudaErr = cudaStreamSynchronize(*context->getCudaStream());
|
||||
// if (cudaErr != 0)
|
||||
// throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr);
|
||||
|
||||
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
|
||||
if (cudaErr != 0)
|
||||
throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr);
|
||||
|
||||
NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
static void batchnormBpCUDNN(const LaunchContext* context,
|
||||
const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* gradO,
|
||||
NDArray* gradI, NDArray* gradG, NDArray* gradB,
|
||||
const double epsilon, const bool isSpatialMode) {
|
||||
|
||||
// input, gradO, gradI -> 4D:nchw, 5D:ncdhw
|
||||
// mean, variance, gamma, beta, gradM, gradV, gradG, gradB -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode
|
||||
// -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode
|
||||
|
||||
const cudnnDataType_t dataType = cudnnDataType(input->dataType());
|
||||
|
||||
const int xRank = input->rankOf();
|
||||
|
||||
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
|
||||
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: can't set stream for cuDNN", err);
|
||||
|
||||
const std::vector<int> xShape = input->getShapeAsVectorInt(); // input and output have same shapes
|
||||
|
||||
std::vector<int> paramsShape, paramsStrides; // mean, variance, gamma and beta have same shapes
|
||||
if(isSpatialMode) { // 1xCx1x1
|
||||
const int iC = mean->lengthOf();
|
||||
const int stride0 = mean->strideAt(0);
|
||||
paramsShape = xRank == 4 ? std::vector<int>({1, iC, 1, 1}) : std::vector<int>({1, iC, 1, 1, 1});
|
||||
paramsStrides = xRank == 4 ? std::vector<int>({iC*stride0, stride0, 1, 1}) : std::vector<int>({iC*stride0, stride0, 1, 1, 1});
|
||||
}
|
||||
else {
|
||||
paramsShape = mean->getShapeAsVectorInt();
|
||||
paramsStrides = xRank == 4 ? std::vector<int>({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector<int>({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)});
|
||||
}
|
||||
|
||||
std::vector<int> xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)};
|
||||
std::vector<int> dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3)};
|
||||
std::vector<int> dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3)};
|
||||
|
||||
if(xRank > 4) { // 5D
|
||||
xStrides.push_back((int)input->strideAt(4));
|
||||
dxStrides.push_back((int)gradI->strideAt(4));
|
||||
dzStrides.push_back((int)gradO->strideAt(4));
|
||||
}
|
||||
|
||||
cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW;
|
||||
|
||||
// input descriptor
|
||||
cudnnTensorDescriptor_t x;
|
||||
cudnnCreateTensorDescriptor(&x);
|
||||
if(input->ews() == 1)
|
||||
err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data());
|
||||
else
|
||||
err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err);
|
||||
|
||||
// gradO descriptor
|
||||
cudnnTensorDescriptor_t dz;
|
||||
cudnnCreateTensorDescriptor(&dz);
|
||||
if(gradO->ews() == 1)
|
||||
err = cudnnSetTensorNdDescriptorEx(dz, format, dataType, xRank, xShape.data());
|
||||
else
|
||||
err = cudnnSetTensorNdDescriptor(dz, dataType, xRank, xShape.data(), dzStrides.data());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err);
|
||||
|
||||
// gradI descriptor
|
||||
cudnnTensorDescriptor_t dx;
|
||||
cudnnCreateTensorDescriptor(&dx);
|
||||
if(input->ews() == 1)
|
||||
err = cudnnSetTensorNdDescriptorEx(dx, format, dataType, xRank, xShape.data());
|
||||
else
|
||||
err = cudnnSetTensorNdDescriptor(dx, dataType, xRank, xShape.data(), dxStrides.data());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err);
|
||||
|
||||
// mean, variance, gamma, gradG and gradB descriptor, the same descriptor for all of them
|
||||
cudnnTensorDescriptor_t params;
|
||||
cudnnCreateTensorDescriptor(¶ms);
|
||||
if(mean->ews() == 1)
|
||||
err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data());
|
||||
else
|
||||
err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/gradG/gradB failed", err);
|
||||
|
||||
// provide scaling parameters
|
||||
const float alpha32(1), beta32(0);
|
||||
double alpha64(1), beta64(0);
|
||||
const void* ptrAlpha = input->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
|
||||
const void* ptrBeta = input->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
|
||||
|
||||
NDArray::prepareSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO});
|
||||
|
||||
// calculations
|
||||
// TODO: we can use cache here
|
||||
err = cudnnBatchNormalizationBackward(*handle, isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION,
|
||||
ptrAlpha, ptrBeta, ptrAlpha, ptrBeta,
|
||||
x, input->getSpecialBuffer(),
|
||||
dz, gradO->getSpecialBuffer(),
|
||||
dx, gradI->getSpecialBuffer(),
|
||||
params,
|
||||
gamma->getSpecialBuffer(), gradG->getSpecialBuffer(), gradB->getSpecialBuffer(),
|
||||
epsilon,
|
||||
nullptr/*mean->getSpecialBuffer()*/, nullptr/*variance->getSpecialBuffer()*/);
|
||||
|
||||
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnBatchNormalizationBackward failed", err);
|
||||
|
||||
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
|
||||
if (cudaErr != 0)
|
||||
throw cuda_exception::build("batchnormBpCUDNN: cudaStreamSynchronize failed !", cudaErr);
|
||||
|
||||
NDArray::registerSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO});
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(batchnorm, ENGINE_CUDA) {
|
||||
|
@ -189,11 +293,21 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) {
|
|||
const bool needPermut = axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1);
|
||||
|
||||
if(needPermut) { // if NHWC
|
||||
std::vector<int> perm = {0, 3, 1, 2}; // NHWC -> NCHW
|
||||
std::vector<int> perm = inRank == 4 ? std::vector<int>({0, 3, 1, 2}) : std::vector<int>({0, 4, 1, 2, 3}); // NHWC -> NCHW
|
||||
input = new NDArray(input->permute(perm));
|
||||
output = new NDArray(output->permute(perm));
|
||||
}
|
||||
|
||||
// cudnn requires gamma and beta to be non-nullptr
|
||||
if(!applyScale) {
|
||||
gamma = new NDArray(mean);
|
||||
*gamma = 1;
|
||||
}
|
||||
if(!applyOffset) {
|
||||
beta = new NDArray(mean);
|
||||
*beta = 0;
|
||||
}
|
||||
|
||||
// calculations
|
||||
batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, output, epsilon, axes.size() == 1);
|
||||
|
||||
|
@ -202,6 +316,12 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) {
|
|||
delete output;
|
||||
}
|
||||
|
||||
if(!applyScale)
|
||||
delete gamma;
|
||||
|
||||
if(!applyOffset)
|
||||
delete beta;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -220,9 +340,6 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) {
|
|||
const int numOfIntArgs = block.getIArguments()->size();
|
||||
const int xRank = input->rankOf();
|
||||
|
||||
// disable cudnn batchnorm so far
|
||||
return false;
|
||||
|
||||
// *********************************** //
|
||||
if(xRank != 4 && xRank != 5)
|
||||
return false;
|
||||
|
@ -269,6 +386,182 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) {
|
|||
return true;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(batchnorm_bp, ENGINE_CUDA) {
|
||||
|
||||
NDArray* input = INPUT_VARIABLE(0);
|
||||
NDArray* mean = INPUT_VARIABLE(1);
|
||||
NDArray* variance = INPUT_VARIABLE(2);
|
||||
NDArray* gamma = nullptr;
|
||||
NDArray* beta = nullptr;
|
||||
NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon
|
||||
|
||||
NDArray* gradI = OUTPUT_VARIABLE(0);
|
||||
NDArray* gradM = OUTPUT_VARIABLE(1);
|
||||
NDArray* gradV = OUTPUT_VARIABLE(2);
|
||||
NDArray* gradG = nullptr;
|
||||
NDArray* gradB = nullptr;
|
||||
|
||||
const bool applyScale = (bool)INT_ARG(0);
|
||||
const bool applyOffset = (bool)INT_ARG(1);
|
||||
const float epsilon = T_ARG(0);
|
||||
|
||||
if(applyScale) {
|
||||
gamma = INPUT_VARIABLE(3);
|
||||
gradG = OUTPUT_VARIABLE(3);
|
||||
}
|
||||
if(applyOffset) {
|
||||
beta = INPUT_VARIABLE(3 + (int)applyScale);
|
||||
gradB = OUTPUT_VARIABLE(3 + (int)applyScale);
|
||||
}
|
||||
|
||||
const int numOfIntArgs = block.getIArguments()->size();
|
||||
const int inRank = input->rankOf();
|
||||
|
||||
// get axes args to normalize input array over
|
||||
std::vector<int> axes;
|
||||
if(numOfIntArgs > 2)
|
||||
for(int i = 2; i < numOfIntArgs; ++i)
|
||||
axes.push_back(INT_ARG(i));
|
||||
else
|
||||
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
|
||||
|
||||
const int numOfAxes = axes.size();
|
||||
REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP CUDNN op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank);
|
||||
|
||||
// evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
|
||||
// for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
|
||||
std::vector<Nd4jLong> expShape;
|
||||
if(numOfAxes == 1)
|
||||
expShape.push_back(input->sizeAt(axes[0]));
|
||||
else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
|
||||
expShape = std::vector<Nd4jLong>(inRank, 1);
|
||||
for(uint i = 0; i < numOfAxes; ++i)
|
||||
expShape[axes[i]] = input->sizeAt(axes[i]);
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
|
||||
REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
|
||||
if(gamma)
|
||||
REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
|
||||
if(beta)
|
||||
REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());
|
||||
|
||||
REQUIRE_TRUE(input->isSameShape(gradO), 0, "BATCHNORM_BP CUDNN op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
|
||||
// types of all input arrays should be the same (except gradO)
|
||||
for(int i = 1; i < block.width() - 2; ++i)
|
||||
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP CUDNN op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
|
||||
|
||||
// cudnn supports NCHW format only
|
||||
const bool needPermut = axes.size() == 1 && mean->lengthOf() != input->sizeAt(1);
|
||||
|
||||
if(needPermut) { // if NHWC
|
||||
std::vector<int> perm = inRank == 4 ? std::vector<int>({0, 3, 1, 2}) : std::vector<int>({0, 4, 1, 2, 3}); // NHWC -> NCHW
|
||||
input = new NDArray(input->permute(perm));
|
||||
gradO = new NDArray(gradO->permute(perm));
|
||||
gradI = new NDArray(gradI->permute(perm));
|
||||
}
|
||||
|
||||
// cudnn requires gamma, gradG, gradB to be non-nullptr
|
||||
if(!applyScale) {
|
||||
gamma = new NDArray(mean);
|
||||
gradG = new NDArray(mean);
|
||||
*gamma = 1;
|
||||
}
|
||||
if(!applyOffset)
|
||||
gradB = new NDArray(mean);
|
||||
|
||||
// calculations
|
||||
batchnormBpCUDNN(block.launchContext(), input, mean, variance, gamma, gradO, gradI, gradG, gradB, epsilon, axes.size() == 1);
|
||||
|
||||
*gradM = 0; // put zeros so far
|
||||
*gradV = 0; // put zeros so far
|
||||
|
||||
if(needPermut) {
|
||||
delete input;
|
||||
delete gradO;
|
||||
delete gradI;
|
||||
}
|
||||
|
||||
if(!applyScale) {
|
||||
delete gamma;
|
||||
delete gradG;
|
||||
}
|
||||
|
||||
if(!applyOffset)
|
||||
delete gradB;
|
||||
|
||||
return Status::OK();
|
||||
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(batchnorm_bp, ENGINE_CUDA) {
|
||||
|
||||
NDArray* input = INPUT_VARIABLE(0);
|
||||
NDArray* mean = INPUT_VARIABLE(1);
|
||||
NDArray* variance = INPUT_VARIABLE(2);
|
||||
NDArray* gamma = nullptr;
|
||||
NDArray* beta = nullptr;
|
||||
NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon
|
||||
|
||||
NDArray* gradI = OUTPUT_VARIABLE(0);
|
||||
NDArray* gradM = OUTPUT_VARIABLE(1);
|
||||
NDArray* gradV = OUTPUT_VARIABLE(2);
|
||||
NDArray* gradG = nullptr;
|
||||
NDArray* gradB = nullptr;
|
||||
|
||||
const int numOfIntArgs = block.getIArguments()->size();
|
||||
const int xRank = input->rankOf();
|
||||
|
||||
// *********************************** //
|
||||
if(xRank != 4 && xRank != 5)
|
||||
return false;
|
||||
|
||||
// *********************************** //
|
||||
const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF;
|
||||
if(badType)
|
||||
return false;
|
||||
|
||||
// *********************************** //
|
||||
// get axes args to normalize input array over
|
||||
std::vector<int> axes;
|
||||
if(numOfIntArgs > 2)
|
||||
for(int i = 2; i < numOfIntArgs; ++i)
|
||||
axes.push_back(INT_ARG(i));
|
||||
else
|
||||
axes.push_back(xRank-1); // default dimension to reduce along is last dimension
|
||||
|
||||
if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4)
|
||||
return false;
|
||||
|
||||
// *********************************** //
|
||||
bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->getShapeInfo(), variance->getShapeInfo());
|
||||
if(gamma)
|
||||
allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gamma->getShapeInfo());
|
||||
if(gradG)
|
||||
allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradG->getShapeInfo());
|
||||
if(gradB)
|
||||
allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradB->getShapeInfo());
|
||||
|
||||
if(!allParamsHaveSameShapeAndStrides)
|
||||
return false;
|
||||
|
||||
// *********************************** //
|
||||
bool isFormatGood = false;
|
||||
if(axes.size() == 1)
|
||||
isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1); // mean [C]
|
||||
else {
|
||||
auto inputShapeModif = input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4]
|
||||
inputShapeModif[0] = 1;
|
||||
isFormatGood = mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4]
|
||||
}
|
||||
if(!isFormatGood)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,412 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
|
||||
#include "cudnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
|
||||
const int iH, const int iW,
|
||||
const int oH, const int oW,
|
||||
const int kH, const int kW,
|
||||
const int sH, const int sW,
|
||||
const int pH, const int pW,
|
||||
const int dH, const int dW,
|
||||
const bool isNCHW) {
|
||||
|
||||
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
|
||||
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
|
||||
|
||||
const bool isPHasymm = pH != (pHsum - pH);
|
||||
const bool isPWasymm = pW != (pWsum - pW);
|
||||
|
||||
if(!isPHasymm && !isPWasymm)
|
||||
return;
|
||||
|
||||
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
|
||||
|
||||
const int iHposition = isNCHW ? 2 : 1;
|
||||
|
||||
if(isPHasymm)
|
||||
newShape[iHposition] += 1;
|
||||
if(isPWasymm)
|
||||
newShape[iHposition + 1] += 1;
|
||||
|
||||
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
|
||||
|
||||
if(isNCHW)
|
||||
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input);
|
||||
else
|
||||
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input);
|
||||
|
||||
input = newInput;
|
||||
|
||||
if(gradI != nullptr)
|
||||
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
|
||||
const int iD, const int iH, const int iW,
|
||||
const int oD, const int oH, const int oW,
|
||||
const int kD, const int kH, const int kW,
|
||||
const int sD, const int sH, const int sW,
|
||||
const int pD, const int pH, const int pW,
|
||||
const int dD, const int dH, const int dW,
|
||||
const bool isNCDHW) {
|
||||
|
||||
const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD);
|
||||
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
|
||||
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
|
||||
|
||||
const bool isPDasymm = pD != (pDsum - pD);
|
||||
const bool isPHasymm = pH != (pHsum - pH);
|
||||
const bool isPWasymm = pW != (pWsum - pW);
|
||||
|
||||
if(!isPDasymm && !isPHasymm && !isPWasymm)
|
||||
return;
|
||||
|
||||
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
|
||||
|
||||
const int iDposition = isNCDHW ? 2 : 1;
|
||||
|
||||
if(isPDasymm)
|
||||
newShape[iDposition] += 1;
|
||||
if(isPHasymm)
|
||||
newShape[iDposition + 1] += 1;
|
||||
if(isPWasymm)
|
||||
newShape[iDposition + 2] += 1;
|
||||
|
||||
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
|
||||
|
||||
if(isNCDHW)
|
||||
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input);
|
||||
else
|
||||
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input);
|
||||
|
||||
input = newInput;
|
||||
|
||||
if(gradI != nullptr)
|
||||
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void pooling2dCUDNN(const LaunchContext* context,
|
||||
const NDArray* input, NDArray* output,
|
||||
const int kH, const int kW,
|
||||
const int sH, const int sW,
|
||||
const int pH, const int pW,
|
||||
const int dH, const int dW,
|
||||
const bool isNCHW, const cudnnPoolingMode_t mode) {
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
|
||||
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: can't set stream for cuDNN", err);
|
||||
|
||||
cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
|
||||
|
||||
// input descriptor
|
||||
cudnnTensorDescriptor_t x;
|
||||
cudnnCreateTensorDescriptor(&x);
|
||||
if(input->ews() == 1)
|
||||
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
|
||||
else
|
||||
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err);
|
||||
|
||||
// output descriptor
|
||||
cudnnTensorDescriptor_t z;
|
||||
cudnnCreateTensorDescriptor(&z);
|
||||
if(output->ews() == 1)
|
||||
err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW);
|
||||
else
|
||||
err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1));
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err);
|
||||
|
||||
// description of pooling
|
||||
cudnnPoolingDescriptor_t pooling;
|
||||
cudnnCreatePoolingDescriptor(&pooling);
|
||||
err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW);
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetPooling2dDescriptor failed", err);
|
||||
|
||||
// provide scaling parameters
|
||||
const float alpha32(1), beta32(0);
|
||||
const double alpha64(1), beta64(0);
|
||||
const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
|
||||
const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
|
||||
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
||||
// run calculation
|
||||
err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnPoolingForward failed", err);
|
||||
|
||||
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
|
||||
if (cudaErr != 0)
|
||||
throw cuda_exception::build("pooling2dCUDNN: cudaStreamSynchronize failed !", cudaErr);
|
||||
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void pooling2dBpCUDNN(const LaunchContext* context,
|
||||
const NDArray* input, const NDArray* gradO,
|
||||
NDArray* gradI,
|
||||
const int kH, const int kW,
|
||||
const int sH, const int sW,
|
||||
const int pH, const int pW,
|
||||
const int dH, const int dW,
|
||||
const bool isNCHW, const cudnnPoolingMode_t mode) {
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
|
||||
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: can't set stream for cuDNN", err);
|
||||
|
||||
cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
|
||||
|
||||
// input and gradI descriptor
|
||||
cudnnTensorDescriptor_t x;
|
||||
cudnnCreateTensorDescriptor(&x);
|
||||
if(input->ews() == 1)
|
||||
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
|
||||
else
|
||||
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input/gradI failed", err);
|
||||
|
||||
// gradO descriptor
|
||||
cudnnTensorDescriptor_t dz;
|
||||
cudnnCreateTensorDescriptor(&dz);
|
||||
if(gradO->ews() == 1)
|
||||
err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
|
||||
else
|
||||
err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err);
|
||||
|
||||
// description of pooling
|
||||
cudnnPoolingDescriptor_t pooling;
|
||||
cudnnCreatePoolingDescriptor(&pooling);
|
||||
err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW);
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetPooling2dDescriptor failed", err);
|
||||
|
||||
// provide scaling parameters
|
||||
const float alpha32(1), beta32(0);
|
||||
const double alpha64(1), beta64(0);
|
||||
const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
|
||||
const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
|
||||
|
||||
NDArray::prepareSpecialUse({gradI}, {input, gradO});
|
||||
|
||||
// run calculation for gradI
|
||||
err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err);
|
||||
|
||||
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
|
||||
if (cudaErr != 0)
|
||||
throw cuda_exception::build("pooling2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr);
|
||||
|
||||
NDArray::registerSpecialUse({gradI}, {input, gradO});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void pooling3dCUDNN(const LaunchContext* context,
|
||||
const NDArray* input, NDArray* output,
|
||||
const int kD, const int kH, const int kW,
|
||||
const int sD, const int sH, const int sW,
|
||||
const int pD, const int pH, const int pW,
|
||||
const int dD, const int dH, const int dW,
|
||||
const bool isNCDHW, const cudnnPoolingMode_t mode) {
|
||||
|
||||
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
|
||||
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err);
|
||||
printf("fffffffffff\n");
|
||||
const int numDims = 5;
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
const int pSizes[] = {pD, pH, pW};
|
||||
const int sSizes[] = {sD, sH, sW};
|
||||
const int kSizes[] = {kD, kH, kW};
|
||||
|
||||
const int xShape[] = {bS, iC, iD, iH, iW};
|
||||
const int zShape[] = {bS, oC, oD, oH, oW};
|
||||
|
||||
const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
|
||||
const int zStrides[] = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)};
|
||||
|
||||
cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
|
||||
|
||||
// input descriptor
|
||||
cudnnTensorDescriptor_t x;
|
||||
cudnnCreateTensorDescriptor(&x);
|
||||
if(input->ews() == 1)
|
||||
err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
|
||||
else
|
||||
err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err);
|
||||
|
||||
// output descriptor
|
||||
cudnnTensorDescriptor_t z;
|
||||
cudnnCreateTensorDescriptor(&z);
|
||||
if(output->ews() == 1)
|
||||
err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape);
|
||||
else
|
||||
err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides);
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err);
|
||||
|
||||
// description of pooling
|
||||
cudnnPoolingDescriptor_t pooling;
|
||||
cudnnCreatePoolingDescriptor(&pooling);
|
||||
err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes);
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetPoolingNdDescriptor failed", err);
|
||||
|
||||
// provide scaling parameters
|
||||
const float alpha32(1), beta32(0);
|
||||
const double alpha64(1), beta64(0);
|
||||
const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
|
||||
const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
|
||||
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
||||
// run calculation
|
||||
err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err);
|
||||
|
||||
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
|
||||
if (cudaErr != 0)
|
||||
throw cuda_exception::build("pooling3dCUDNN: cudaStreamSynchronize failed !", cudaErr);
|
||||
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void pooling3dBpCUDNN(const LaunchContext* context,
|
||||
const NDArray* input, const NDArray* gradO,
|
||||
NDArray* gradI,
|
||||
const int kD, const int kH, const int kW,
|
||||
const int sD, const int sH, const int sW,
|
||||
const int pD, const int pH, const int pW,
|
||||
const int dD, const int dH, const int dW,
|
||||
const bool isNCDHW, const cudnnPoolingMode_t mode) {
|
||||
|
||||
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
|
||||
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: can't set stream for cuDNN", err);
|
||||
|
||||
const int numDims = 5;
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
const int pSizes[] = {pD, pH, pW};
|
||||
const int sSizes[] = {sD, sH, sW};
|
||||
const int kSizes[] = {kD, kH, kW};
|
||||
|
||||
const int xShape[] = {bS, iC, iD, iH, iW};
|
||||
const int dzShape[] = {bS, oC, oD, oH, oW};
|
||||
|
||||
const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
|
||||
const int dzStrides[] = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)};
|
||||
|
||||
cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
|
||||
|
||||
// input and gradI descriptor
|
||||
cudnnTensorDescriptor_t x;
|
||||
cudnnCreateTensorDescriptor(&x);
|
||||
if(input->ews() == 1)
|
||||
err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
|
||||
else
|
||||
err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input/gradI failed", err);
|
||||
|
||||
// gradO descriptor
|
||||
cudnnTensorDescriptor_t dz;
|
||||
cudnnCreateTensorDescriptor(&dz);
|
||||
if(gradO->ews() == 1)
|
||||
err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape);
|
||||
else
|
||||
err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides);
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err);
|
||||
|
||||
// description of pooling
|
||||
cudnnPoolingDescriptor_t pooling;
|
||||
cudnnCreatePoolingDescriptor(&pooling);
|
||||
err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes);
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetPoolingNdDescriptor failed", err);
|
||||
|
||||
// provide scaling parameters
|
||||
const float alpha32(1), beta32(0);
|
||||
const double alpha64(1), beta64(0);
|
||||
const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
|
||||
const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
|
||||
|
||||
// cudnn maxpool2d_bp api requires ff output as one of input arguments
|
||||
if(mode == CUDNN_POOLING_MAX) {
|
||||
|
||||
NDArray temp(gradO);
|
||||
|
||||
NDArray::prepareSpecialUse({gradI}, {input, gradO, &temp});
|
||||
|
||||
// run ff calculation
|
||||
err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, dz, temp.specialBuffer());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err);
|
||||
|
||||
// run bp calculation for gradI
|
||||
err = cudnnPoolingBackward(*handle, pooling, alpha, dz, temp.getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err);
|
||||
|
||||
NDArray::registerSpecialUse({gradI}, {input, gradO, &temp});
|
||||
}
|
||||
else {
|
||||
|
||||
NDArray::prepareSpecialUse({gradI}, {input, gradO});
|
||||
|
||||
// run bp calculation for gradI
|
||||
err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
|
||||
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err);
|
||||
|
||||
NDArray::registerSpecialUse({gradI}, {input, gradO});
|
||||
}
|
||||
|
||||
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
|
||||
if (cudaErr != 0)
|
||||
throw cuda_exception::build("pooling3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -30,8 +30,8 @@
|
|||
|
||||
#include <cudnn.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
DECLARE_PLATFORM(conv2d, ENGINE_CUDA);
|
||||
|
@ -46,6 +46,18 @@ namespace platforms {
|
|||
DECLARE_PLATFORM(batchnorm, ENGINE_CUDA);
|
||||
DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA);
|
||||
|
||||
DECLARE_PLATFORM(avgpool2d, ENGINE_CUDA);
|
||||
DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CUDA);
|
||||
|
||||
DECLARE_PLATFORM(maxpool2d, ENGINE_CUDA);
|
||||
DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CUDA);
|
||||
|
||||
DECLARE_PLATFORM(avgpool3dnew, ENGINE_CUDA);
|
||||
DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CUDA);
|
||||
|
||||
DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA);
|
||||
DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) {
|
||||
switch (dataType) {
|
||||
|
@ -65,91 +77,62 @@ FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) {
|
|||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
FORCEINLINE void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
|
||||
const int iH, const int iW,
|
||||
const int oH, const int oW,
|
||||
const int kH, const int kW,
|
||||
const int sH, const int sW,
|
||||
const int pH, const int pW,
|
||||
const int dH, const int dW,
|
||||
const bool isNCHW) {
|
||||
|
||||
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
|
||||
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
|
||||
|
||||
const bool isPHasymm = pH != (pHsum - pH);
|
||||
const bool isPWasymm = pW != (pWsum - pW);
|
||||
|
||||
if(!isPHasymm && !isPWasymm)
|
||||
return;
|
||||
|
||||
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
|
||||
|
||||
const int iHposition = isNCHW ? 2 : 1;
|
||||
|
||||
if(isPHasymm)
|
||||
newShape[iHposition] += 1;
|
||||
if(isPWasymm)
|
||||
newShape[iHposition + 1] += 1;
|
||||
|
||||
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
|
||||
|
||||
if(isNCHW)
|
||||
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input);
|
||||
else
|
||||
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input);
|
||||
|
||||
input = newInput;
|
||||
|
||||
if(gradI != nullptr)
|
||||
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
|
||||
}
|
||||
|
||||
void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
|
||||
const int iH, const int iW,
|
||||
const int oH, const int oW,
|
||||
const int kH, const int kW,
|
||||
const int sH, const int sW,
|
||||
const int pH, const int pW,
|
||||
const int dH, const int dW,
|
||||
const bool isNCHW);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
FORCEINLINE void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
|
||||
const int iD, const int iH, const int iW,
|
||||
const int oD, const int oH, const int oW,
|
||||
const int kD, const int kH, const int kW,
|
||||
const int sD, const int sH, const int sW,
|
||||
const int pD, const int pH, const int pW,
|
||||
const int dD, const int dH, const int dW,
|
||||
const bool isNCDHW) {
|
||||
void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
|
||||
const int iD, const int iH, const int iW,
|
||||
const int oD, const int oH, const int oW,
|
||||
const int kD, const int kH, const int kW,
|
||||
const int sD, const int sH, const int sW,
|
||||
const int pD, const int pH, const int pW,
|
||||
const int dD, const int dH, const int dW,
|
||||
const bool isNCDHW);
|
||||
|
||||
const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD);
|
||||
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
|
||||
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void pooling2dCUDNN(const LaunchContext* context,
|
||||
const NDArray* input, NDArray* output,
|
||||
const int kH, const int kW,
|
||||
const int sH, const int sW,
|
||||
const int pH, const int pW,
|
||||
const int dH, const int dW,
|
||||
const bool isNCHW, const cudnnPoolingMode_t mode);
|
||||
|
||||
const bool isPDasymm = pD != (pDsum - pD);
|
||||
const bool isPHasymm = pH != (pHsum - pH);
|
||||
const bool isPWasymm = pW != (pWsum - pW);
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void pooling2dBpCUDNN(const LaunchContext* context,
|
||||
const NDArray* input, const NDArray* gradO,
|
||||
NDArray* gradI,
|
||||
const int kH, const int kW,
|
||||
const int sH, const int sW,
|
||||
const int pH, const int pW,
|
||||
const int dH, const int dW,
|
||||
const bool isNCHW, const cudnnPoolingMode_t mode);
|
||||
|
||||
if(!isPDasymm && !isPHasymm && !isPWasymm)
|
||||
return;
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void pooling3dCUDNN(const LaunchContext* context,
|
||||
const NDArray* input, NDArray* output,
|
||||
const int kD, const int kH, const int kW,
|
||||
const int sD, const int sH, const int sW,
|
||||
const int pD, const int pH, const int pW,
|
||||
const int dD, const int dH, const int dW,
|
||||
const bool isNCDHW, const cudnnPoolingMode_t mode);
|
||||
|
||||
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
|
||||
|
||||
const int iDposition = isNCDHW ? 2 : 1;
|
||||
|
||||
if(isPDasymm)
|
||||
newShape[iDposition] += 1;
|
||||
if(isPHasymm)
|
||||
newShape[iDposition + 1] += 1;
|
||||
if(isPWasymm)
|
||||
newShape[iDposition + 2] += 1;
|
||||
|
||||
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
|
||||
|
||||
if(isNCDHW)
|
||||
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input);
|
||||
else
|
||||
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input);
|
||||
|
||||
input = newInput;
|
||||
|
||||
if(gradI != nullptr)
|
||||
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void pooling3dBpCUDNN(const LaunchContext* context,
|
||||
const NDArray* input, const NDArray* gradO,
|
||||
NDArray* gradI,
|
||||
const int kD, const int kH, const int kW,
|
||||
const int sD, const int sH, const int sW,
|
||||
const int pD, const int pH, const int pW,
|
||||
const int dD, const int dH, const int dW,
|
||||
const bool isNCDHW, const cudnnPoolingMode_t mode);
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
|
||||
#include "cudnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(maxpool2d, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - paddingModee;
|
||||
const auto kH = INT_ARG(0);
|
||||
const auto kW = INT_ARG(1);
|
||||
const auto sH = INT_ARG(2);
|
||||
const auto sW = INT_ARG(3);
|
||||
auto pH = INT_ARG(4);
|
||||
auto pW = INT_ARG(5);
|
||||
const auto dH = INT_ARG(6);
|
||||
const auto dW = INT_ARG(7);
|
||||
const auto paddingMode = static_cast<bool>(INT_ARG(8));
|
||||
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int oH = 0;
|
||||
int oW = 0;
|
||||
|
||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||
|
||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
|
||||
|
||||
if (paddingMode)
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(maxpool2d, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
|
||||
|
||||
return goodType && input->dataType() == output->dataType();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||
|
||||
const auto kH = INT_ARG(0); // filter(kernel) height
|
||||
const auto kW = INT_ARG(1); // filter(kernel) width
|
||||
const auto sH = INT_ARG(2); // strides height
|
||||
const auto sW = INT_ARG(3); // strides width
|
||||
auto pH = INT_ARG(4); // paddings height
|
||||
auto pW = INT_ARG(5); // paddings width
|
||||
const auto dH = INT_ARG(6); // dilations height
|
||||
const auto dW = INT_ARG(7); // dilations width
|
||||
const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||
const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
|
||||
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(paddingMode) // SAME
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(maxpool2d_bp, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||
|
||||
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
|
||||
|
||||
return goodType && (input->dataType() == gradO->dataType())
|
||||
&& (input->dataType() == gradI->dataType())
|
||||
&& shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,140 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
|
||||
#include "cudnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
||||
|
||||
int kD = INT_ARG(0); // filter(kernel) depth
|
||||
int kH = INT_ARG(1); // filter(kernel) height
|
||||
int kW = INT_ARG(2); // filter(kernel) width
|
||||
int sD = INT_ARG(3); // strides depth
|
||||
int sH = INT_ARG(4); // strides height
|
||||
int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
int dD = INT_ARG(9); // dilations depth
|
||||
int dH = INT_ARG(10); // dilations height
|
||||
int dW = INT_ARG(11); // dilations width
|
||||
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
// int extraParam0 = INT_ARG(13);
|
||||
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||
|
||||
if(paddingMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(maxpool3dnew, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
|
||||
|
||||
return goodType && input->dataType() == output->dataType();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
|
||||
const int kD = INT_ARG(0); // filter(kernel) depth
|
||||
const int kH = INT_ARG(1); // filter(kernel) height
|
||||
const int kW = INT_ARG(2); // filter(kernel) width
|
||||
const int sD = INT_ARG(3); // strides depth
|
||||
const int sH = INT_ARG(4); // strides height
|
||||
const int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
const int dD = INT_ARG(9); // dilations depth
|
||||
const int dH = INT_ARG(10); // dilations height
|
||||
const int dW = INT_ARG(11); // dilations width
|
||||
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
// const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging
|
||||
const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CUDA) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
|
||||
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
|
||||
|
||||
return goodType && (input->dataType() == gradO->dataType())
|
||||
&& (input->dataType() == gradI->dataType())
|
||||
&& shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -30,111 +30,231 @@
|
|||
using namespace dnnl;
|
||||
using namespace samediff;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
|
||||
input->rankOf());
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
auto argI = *(block.getIArguments());
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
|
||||
input->rankOf());
|
||||
|
||||
const auto kH = INT_ARG(0);
|
||||
const auto kW = INT_ARG(1);
|
||||
const auto sH = INT_ARG(2);
|
||||
const auto sW = INT_ARG(3);
|
||||
int pH = INT_ARG(4);
|
||||
int pW = INT_ARG(5);
|
||||
const auto dH = INT_ARG(6);
|
||||
const auto dW = INT_ARG(7);
|
||||
const auto isSameMode = static_cast<bool>(INT_ARG(8));
|
||||
const auto extraParam0 = INT_ARG(9);
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
auto argI = *(block.getIArguments());
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
|
||||
dH, dW);
|
||||
const auto kH = INT_ARG(0);
|
||||
const auto kW = INT_ARG(1);
|
||||
const auto sH = INT_ARG(2);
|
||||
const auto sW = INT_ARG(3);
|
||||
int pH = INT_ARG(4);
|
||||
int pW = INT_ARG(5);
|
||||
const auto dH = INT_ARG(6);
|
||||
const auto dW = INT_ARG(7);
|
||||
const auto isSameMode = static_cast<bool>(INT_ARG(8));
|
||||
const auto extraParam0 = INT_ARG(9);
|
||||
|
||||
int oH = 0;
|
||||
int oW = 0;
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
|
||||
dH, dW);
|
||||
|
||||
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
int oH = 0;
|
||||
int oW = 0;
|
||||
|
||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
if (!isNCHW) {
|
||||
input = new NDArray(
|
||||
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
output = new NDArray(
|
||||
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||
}
|
||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||
|
||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||
|
||||
if (isSameMode)
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
const int bS = input->sizeAt(0);
|
||||
const int iC = input->sizeAt(1);
|
||||
const int oC = output->sizeAt(1);
|
||||
|
||||
auto poolingMode = PoolingType::AVG_POOL;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
|
||||
algorithm,
|
||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||
&user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||
pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
auto pool_src_memory = user_src_memory;
|
||||
dnnl::stream stream(engine);
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
}
|
||||
stream.wait();
|
||||
|
||||
//streams[0].submitAndWait();
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
delete output;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
}
|
||||
if (!isNCHW) {
|
||||
input = new NDArray(
|
||||
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
output = new NDArray(
|
||||
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||
}
|
||||
|
||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||
|
||||
if (isSameMode)
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
const int bS = input->sizeAt(0);
|
||||
const int iC = input->sizeAt(1);
|
||||
const int oC = output->sizeAt(1);
|
||||
|
||||
auto poolingMode = PoolingType::AVG_POOL;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
|
||||
algorithm,
|
||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||
&user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||
pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
auto pool_src_memory = user_src_memory;
|
||||
dnnl::stream stream(engine);
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
}
|
||||
stream.wait();
|
||||
|
||||
//streams[0].submitAndWait();
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
delete output;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto gradO = INPUT_VARIABLE(
|
||||
1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(
|
||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||
|
||||
int kH = INT_ARG(0); // filter(kernel) height
|
||||
int kW = INT_ARG(1); // filter(kernel) width
|
||||
int sH = INT_ARG(2); // strides height
|
||||
int sW = INT_ARG(3); // strides width
|
||||
int pH = INT_ARG(4); // paddings height
|
||||
int pW = INT_ARG(5); // paddings width
|
||||
int dH = INT_ARG(6); // dilations height
|
||||
int dW = INT_ARG(7); // dilations width
|
||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||
int extraParam0 = INT_ARG(9);
|
||||
int isNCHW =
|
||||
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
|
||||
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
||||
indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(
|
||||
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(
|
||||
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
|
||||
if (!isNCHW) {
|
||||
input = new NDArray(input->permute(
|
||||
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradI = new NDArray(gradI->permute(
|
||||
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradO = new NDArray(gradO->permute(
|
||||
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||
}
|
||||
|
||||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
auto poolingMode = PoolingType::AVG_POOL;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
|
||||
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||
&user_diff_src_md, &user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
|
||||
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
|
||||
pool_dst_md, pool_strides, pool_kernel, pool_padding,
|
||||
pool_padding_r);
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
|
||||
pool_kernel, pool_padding, pool_padding_r);
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
dnnl::stream stream(engine);
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||
}
|
||||
stream.wait();
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
delete gradI;
|
||||
delete gradO;
|
||||
}
|
||||
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,149 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author saudet
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/PlatformHelper.h>
|
||||
#include <ops/declarable/OpRegistrator.h>
|
||||
#include <platform_boilerplate.h>
|
||||
|
||||
#include <helpers/MKLDNNStream.h>
|
||||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto gradO = INPUT_VARIABLE(
|
||||
1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(
|
||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||
|
||||
int kH = INT_ARG(0); // filter(kernel) height
|
||||
int kW = INT_ARG(1); // filter(kernel) width
|
||||
int sH = INT_ARG(2); // strides height
|
||||
int sW = INT_ARG(3); // strides width
|
||||
int pH = INT_ARG(4); // paddings height
|
||||
int pW = INT_ARG(5); // paddings width
|
||||
int dH = INT_ARG(6); // dilations height
|
||||
int dW = INT_ARG(7); // dilations width
|
||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||
int extraParam0 = INT_ARG(9);
|
||||
int isNCHW =
|
||||
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
|
||||
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
||||
indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(
|
||||
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(
|
||||
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
|
||||
if (!isNCHW) {
|
||||
input = new NDArray(input->permute(
|
||||
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradI = new NDArray(gradI->permute(
|
||||
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradO = new NDArray(gradO->permute(
|
||||
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||
}
|
||||
|
||||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
auto poolingMode = PoolingType::AVG_POOL;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
|
||||
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||
&user_diff_src_md, &user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
|
||||
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
|
||||
pool_dst_md, pool_strides, pool_kernel, pool_padding,
|
||||
pool_padding_r);
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
|
||||
pool_kernel, pool_padding, pool_padding_r);
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
dnnl::stream stream(engine);
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||
}
|
||||
stream.wait();
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
delete gradI;
|
||||
delete gradO;
|
||||
}
|
||||
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -34,24 +34,23 @@ namespace ops {
|
|||
namespace platforms {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
|
||||
static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
|
||||
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
|
||||
const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode,
|
||||
const int isNCHW) {
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
|
||||
indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
||||
empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||
empty);
|
||||
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||
|
||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
|
||||
bias, output,
|
||||
|
@ -61,13 +60,12 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
|
|||
&user_bias_md, &user_dst_md,
|
||||
conv_strides, conv_padding, conv_padding_r, conv_dilation);
|
||||
|
||||
auto conv_desc = bias != nullptr
|
||||
? convolution_forward::desc(prop_kind::forward,
|
||||
auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_auto, conv_src_md,
|
||||
conv_weights_md, conv_bias_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r)
|
||||
: convolution_forward::desc(prop_kind::forward,
|
||||
: convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_auto, conv_src_md,
|
||||
conv_weights_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
|
@ -112,6 +110,135 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
|
|||
stream.wait();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
static void conv2dBpMKLDNN(nd4j::graph::Context &block,
|
||||
const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
|
||||
NDArray *gradI, NDArray *gradW, NDArray *gradB,
|
||||
const int kH, const int kW, const int sH,const int sW, int pH, int pW, const int dH, const int dW,
|
||||
const int paddingMode, const int isNCHW) {
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||
|
||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
|
||||
gradB, gradO,
|
||||
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
|
||||
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
||||
&user_src_md, &user_diff_src_md, &user_weights_md,
|
||||
&user_diff_weights_md, &user_bias_md, &user_dst_md,
|
||||
conv_strides, conv_padding, conv_padding_r, conv_dilation);
|
||||
auto conv_desc = gradB != nullptr
|
||||
? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
|
||||
: convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( LaunchContext::defaultContext()->engine()));
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
if (gradW != nullptr) {
|
||||
auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
|
||||
: convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
|
||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
|
||||
|
||||
auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
|
||||
auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||
auto userW_dst_memory = dnnl::memory(user_dst_md, engine,const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convW_src_memory = userW_src_memory;
|
||||
|
||||
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
||||
convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
|
||||
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,convW_src_memory);
|
||||
}
|
||||
|
||||
auto convW_weights_memory = userW_weights_memory;
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||
convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||
}
|
||||
|
||||
auto convW_dst_memory = userW_dst_memory;
|
||||
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory);
|
||||
}
|
||||
|
||||
if (gradB != nullptr) {
|
||||
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer());
|
||||
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||
}
|
||||
else {
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||
}
|
||||
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
||||
userW_weights_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
|
||||
if (gradI != nullptr) {
|
||||
|
||||
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
|
||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
|
||||
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userI_weights_memory = dnnl::memory(user_weights_md, engine,const_cast<NDArray *>(weights)->buffer());
|
||||
auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convI_src_memory = userI_src_memory;
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto convI_weights_memory = userI_weights_memory;
|
||||
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
|
||||
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory);
|
||||
}
|
||||
|
||||
auto convI_dst_memory = userI_dst_memory;
|
||||
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory);
|
||||
}
|
||||
|
||||
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||
{{DNNL_ARG_DIFF_DST, convI_dst_memory},
|
||||
{DNNL_ARG_WEIGHTS, convI_weights_memory},
|
||||
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
|
||||
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(conv2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
|
@ -132,7 +259,7 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
|
|||
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
|
||||
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
|
||||
|
||||
conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
|
||||
conv2dMKLDNN(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -152,6 +279,7 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
|
@ -172,158 +300,11 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
|
|||
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||
"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",
|
||||
input->rankOf());
|
||||
REQUIRE_TRUE(weights->rankOf() == 4, 0,
|
||||
"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",
|
||||
weights->rankOf());
|
||||
REQUIRE_TRUE(gradO->rankOf() == 4, 0,
|
||||
"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",
|
||||
gradO->rankOf());
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",input->rankOf());
|
||||
REQUIRE_TRUE(weights->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",weights->rankOf());
|
||||
REQUIRE_TRUE(gradO->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",gradO->rankOf());
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
||||
indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
|
||||
gradB, gradO,
|
||||
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
|
||||
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
||||
&user_src_md, &user_diff_src_md, &user_weights_md,
|
||||
&user_diff_weights_md, &user_bias_md, &user_dst_md,
|
||||
conv_strides, conv_padding, conv_padding_r, conv_dilation);
|
||||
auto conv_desc = gradB != nullptr
|
||||
? convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_auto, conv_src_md,
|
||||
conv_weights_md, conv_bias_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r)
|
||||
: convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_auto, conv_src_md,
|
||||
conv_weights_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r);
|
||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
|
||||
LaunchContext::defaultContext()->engine()));
|
||||
if (gradW != nullptr) {
|
||||
auto convW_desc = gradB != nullptr
|
||||
? convolution_backward_weights::desc(
|
||||
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
|
||||
: convolution_backward_weights::desc(
|
||||
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
||||
conv_prim_desc);
|
||||
auto userW_src_memory = dnnl::memory(user_src_md, engine,
|
||||
const_cast<NDArray *>(input)->buffer());
|
||||
auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||
auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
|
||||
const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convW_src_memory = userW_src_memory;
|
||||
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
||||
convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
|
||||
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
||||
convW_src_memory);
|
||||
}
|
||||
|
||||
auto convW_weights_memory = userW_weights_memory;
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||
convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||
}
|
||||
|
||||
auto convW_dst_memory = userW_dst_memory;
|
||||
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
||||
convW_dst_memory);
|
||||
}
|
||||
|
||||
if (gradB != nullptr) {
|
||||
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||
gradB->buffer());
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||
} else {
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||
}
|
||||
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
||||
userW_weights_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
|
||||
if (gradI != nullptr) {
|
||||
auto convI_desc =
|
||||
convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md,
|
||||
conv_weights_md, conv_dst_md, conv_strides, conv_dilation,
|
||||
conv_padding, conv_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
||||
conv_prim_desc);
|
||||
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
|
||||
const_cast<NDArray *>(weights)->buffer());
|
||||
auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
|
||||
const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convI_src_memory = userI_src_memory;
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto convI_weights_memory = userI_weights_memory;
|
||||
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
|
||||
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
||||
convI_weights_memory);
|
||||
}
|
||||
|
||||
auto convI_dst_memory = userI_dst_memory;
|
||||
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
||||
convI_dst_memory);
|
||||
}
|
||||
|
||||
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||
{{DNNL_ARG_DIFF_DST, convI_dst_memory},
|
||||
{DNNL_ARG_WEIGHTS, convI_weights_memory},
|
||||
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
|
||||
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
||||
userI_src_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
};
|
||||
conv2dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -34,62 +34,23 @@ namespace ops {
|
|||
namespace platforms {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
auto output = OUTPUT_VARIABLE(
|
||||
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||
"CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !",
|
||||
input->rankOf());
|
||||
REQUIRE_TRUE(weights->rankOf() == 5, 0,
|
||||
"CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !",
|
||||
weights->rankOf());
|
||||
|
||||
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
||||
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
||||
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
||||
int sD = INT_ARG(3); // strides depth
|
||||
int sH = INT_ARG(4); // strides height
|
||||
int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
int dD = INT_ARG(9); // dilations depth
|
||||
int dH = INT_ARG(10); // dilations height
|
||||
int dW = INT_ARG(11); // dilations width
|
||||
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||
int isNCDHW =
|
||||
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||
static void conv3dMKLDNN(nd4j::graph::Context &block,
|
||||
const NDArray *input, const NDArray *weights, const NDArray *bias,
|
||||
NDArray *output,
|
||||
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW,
|
||||
const int paddingMode, const int isNCDHW) {
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
|
||||
"CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !",
|
||||
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||
if (bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
|
||||
"CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
|
||||
oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
||||
empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||
empty);
|
||||
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty);
|
||||
|
||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode,
|
||||
isNCDHW,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
|
||||
nullptr, bias, output,
|
||||
|
@ -98,151 +59,73 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
|
|||
&user_src_md, nullptr, &user_weights_md, nullptr,
|
||||
&user_bias_md, &user_dst_md,
|
||||
conv_strides, conv_padding, conv_padding_r, conv_dilation);
|
||||
auto conv_desc = bias != nullptr
|
||||
? convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_auto, conv_src_md,
|
||||
conv_weights_md, conv_bias_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r)
|
||||
: convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_auto, conv_src_md,
|
||||
conv_weights_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r);
|
||||
auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
|
||||
: convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
|
||||
auto user_weights_memory = dnnl::memory(user_weights_md, engine,
|
||||
const_cast<NDArray *>(weights)->buffer());
|
||||
auto user_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
|
||||
auto conv_src_memory = user_src_memory;
|
||||
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
|
||||
}
|
||||
|
||||
auto conv_weights_memory = user_weights_memory;
|
||||
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
|
||||
conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine);
|
||||
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
|
||||
conv_weights_memory);
|
||||
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, conv_weights_memory);
|
||||
}
|
||||
|
||||
auto conv_dst_memory = user_dst_memory;
|
||||
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
|
||||
if (bias != nullptr) {
|
||||
auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
|
||||
auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->getBuffer());
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
|
||||
{DNNL_ARG_WEIGHTS, conv_weights_memory},
|
||||
{DNNL_ARG_BIAS, conv_bias_memory},
|
||||
{DNNL_ARG_DST, conv_dst_memory}});
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
|
||||
{DNNL_ARG_WEIGHTS, conv_weights_memory},
|
||||
{DNNL_ARG_DST, conv_dst_memory}});
|
||||
}
|
||||
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
|
||||
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc())
|
||||
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
|
||||
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||
if (::optimalLevel() < 2)
|
||||
return false;
|
||||
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
auto output = OUTPUT_VARIABLE(
|
||||
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weights = INPUT_VARIABLE(
|
||||
1); // [kD, kH, kW, iC, oC] always
|
||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
|
||||
auto gradI = OUTPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
auto gradW = OUTPUT_VARIABLE(
|
||||
1); // [kD, kH, kW, iC, oC] always
|
||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||
"CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !",
|
||||
input->rankOf());
|
||||
REQUIRE_TRUE(weights->rankOf() == 5, 0,
|
||||
"CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !",
|
||||
weights->rankOf());
|
||||
REQUIRE_TRUE(gradO->rankOf() == 5, 0,
|
||||
"CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !",
|
||||
gradO->rankOf());
|
||||
|
||||
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
||||
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
||||
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
||||
int sD = INT_ARG(3); // strides depth
|
||||
int sH = INT_ARG(4); // strides height
|
||||
int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
int dD = INT_ARG(9); // dilations depth
|
||||
int dH = INT_ARG(10); // dilations height
|
||||
int dW = INT_ARG(11); // dilations width
|
||||
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
int isNDHWC =
|
||||
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||
static void conv3dBpMKLDNN(nd4j::graph::Context &block,
|
||||
const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
|
||||
NDArray *gradI, NDArray *gradW, NDArray *gradB,
|
||||
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW,
|
||||
const int paddingMode, const int isNCDHW) {
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
if(isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
int trueoD, trueoH, trueoW; // true output depth/height/width
|
||||
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH,
|
||||
dW, iD, iH, iW, isSameMode);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||
{bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||
"CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !",
|
||||
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
|
||||
"CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !",
|
||||
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||
if (bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
|
||||
"CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
|
||||
oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||
|
||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
||||
isNDHWC,
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode,
|
||||
isNCDHW,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
|
||||
gradW, gradB, gradO,
|
||||
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
|
||||
|
@ -250,43 +133,30 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
|
|||
&user_src_md, &user_diff_src_md, &user_weights_md,
|
||||
&user_diff_weights_md, &user_bias_md, &user_dst_md,
|
||||
conv_strides, conv_padding, conv_padding_r, conv_dilation);
|
||||
auto conv_desc = gradB != nullptr
|
||||
? convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_auto, conv_src_md,
|
||||
conv_weights_md, conv_bias_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r)
|
||||
: convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_auto, conv_src_md,
|
||||
conv_weights_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r);
|
||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
|
||||
LaunchContext::defaultContext()->engine()));
|
||||
if (gradW != nullptr) {
|
||||
auto convW_desc = gradB != nullptr
|
||||
? convolution_backward_weights::desc(
|
||||
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
|
||||
: convolution_backward_weights::desc(
|
||||
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
||||
conv_prim_desc);
|
||||
auto userW_src_memory = dnnl::memory(user_src_md, engine,
|
||||
const_cast<NDArray *>(input)->buffer());
|
||||
auto conv_desc = gradB != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
|
||||
: convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()));
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
if (gradW != nullptr) {
|
||||
|
||||
auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
|
||||
: convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
|
||||
|
||||
auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
|
||||
auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||
auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
|
||||
const_cast<NDArray *>(gradO)->buffer());
|
||||
auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convW_src_memory = userW_src_memory;
|
||||
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
||||
convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
|
||||
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
||||
convW_src_memory);
|
||||
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory);
|
||||
}
|
||||
|
||||
auto convW_weights_memory = userW_weights_memory;
|
||||
|
@ -297,65 +167,53 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
|
|||
auto convW_dst_memory = userW_dst_memory;
|
||||
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
||||
convW_dst_memory);
|
||||
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory);
|
||||
}
|
||||
|
||||
if (gradB != nullptr) {
|
||||
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||
gradB->buffer());
|
||||
|
||||
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer());
|
||||
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||
{{DNNL_ARG_SRC, convW_src_memory},
|
||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||
}
|
||||
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
||||
userW_weights_memory);
|
||||
}
|
||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc())
|
||||
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory);
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
if (gradI != nullptr) {
|
||||
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto,
|
||||
conv_diff_src_md, conv_weights_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding_r);
|
||||
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
||||
conv_prim_desc);
|
||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
|
||||
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
|
||||
const_cast<NDArray *>(weights)->buffer());
|
||||
auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
|
||||
const_cast<NDArray *>(gradO)->buffer());
|
||||
auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer());
|
||||
auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
|
||||
|
||||
auto convI_src_memory = userI_src_memory;
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc())
|
||||
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto convI_weights_memory = userI_weights_memory;
|
||||
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
|
||||
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
||||
convI_weights_memory);
|
||||
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory);
|
||||
}
|
||||
|
||||
auto convI_dst_memory = userI_dst_memory;
|
||||
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
||||
convI_dst_memory);
|
||||
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory);
|
||||
}
|
||||
|
||||
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||
|
@ -363,30 +221,128 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
|
|||
{DNNL_ARG_WEIGHTS, convI_weights_memory},
|
||||
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
|
||||
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
||||
userI_src_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc())
|
||||
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
|
||||
|
||||
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
||||
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
||||
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
||||
int sD = INT_ARG(3); // strides depth
|
||||
int sH = INT_ARG(4); // strides height
|
||||
int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
int dD = INT_ARG(9); // dilations depth
|
||||
int dH = INT_ARG(10); // dilations height
|
||||
int dW = INT_ARG(11); // dilations width
|
||||
int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||
if (bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
if (paddingMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
conv3dMKLDNN(block, input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
|
||||
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||
if (::optimalLevel() < 2)
|
||||
return false;
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
|
||||
REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf());
|
||||
|
||||
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
||||
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
||||
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
||||
int sD = INT_ARG(3); // strides depth
|
||||
int sH = INT_ARG(4); // strides height
|
||||
int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
int dD = INT_ARG(9); // dilations depth
|
||||
int dH = INT_ARG(10); // dilations height
|
||||
int dW = INT_ARG(11); // dilations width
|
||||
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
if(paddingMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
int trueoD, trueoH, trueoW; // true output depth/height/width
|
||||
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||
if (bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
conv3dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weights = INPUT_VARIABLE(
|
||||
1); // [kD, kH, kW, iC, oC] always
|
||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
|
||||
auto gradI = OUTPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
auto gradW = OUTPUT_VARIABLE(
|
||||
1); // [kD, kH, kW, iC, oC] always
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||
|
||||
return block.isUseMKLDNN() &&
|
||||
|
|
|
@ -177,7 +177,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||
static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||
const int paddingMode) {
|
||||
|
||||
|
@ -492,7 +492,7 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {
|
|||
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||
}
|
||||
|
||||
deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
|
||||
deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
|
||||
|
||||
delete weights;
|
||||
delete gradW;
|
||||
|
|
|
@ -421,7 +421,7 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {
|
|||
return block.isUseMKLDNN() && mC == 1 &&
|
||||
(
|
||||
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF) ||
|
||||
(xType==DataType::BFLOAT16 && wType==DataType::BFLOAT16 && bType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) ||
|
||||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -29,117 +29,258 @@
|
|||
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
|
||||
input->rankOf());
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
auto argI = *(block.getIArguments());
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
|
||||
input->rankOf());
|
||||
|
||||
const auto kH = INT_ARG(0);
|
||||
const auto kW = INT_ARG(1);
|
||||
const auto sH = INT_ARG(2);
|
||||
const auto sW = INT_ARG(3);
|
||||
int pH = INT_ARG(4);
|
||||
int pW = INT_ARG(5);
|
||||
const auto dH = INT_ARG(6);
|
||||
const auto dW = INT_ARG(7);
|
||||
const auto isSameMode = static_cast<bool>(INT_ARG(8));
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
auto argI = *(block.getIArguments());
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
|
||||
dH, dW);
|
||||
const auto kH = INT_ARG(0);
|
||||
const auto kW = INT_ARG(1);
|
||||
const auto sH = INT_ARG(2);
|
||||
const auto sW = INT_ARG(3);
|
||||
int pH = INT_ARG(4);
|
||||
int pW = INT_ARG(5);
|
||||
const auto dH = INT_ARG(6);
|
||||
const auto dW = INT_ARG(7);
|
||||
const auto isSameMode = static_cast<bool>(INT_ARG(8));
|
||||
|
||||
int oH = 0;
|
||||
int oW = 0;
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
|
||||
dH, dW);
|
||||
|
||||
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
int oH = 0;
|
||||
int oW = 0;
|
||||
|
||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
if (!isNCHW) {
|
||||
input = new NDArray(
|
||||
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
output = new NDArray(
|
||||
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||
}
|
||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||
|
||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||
|
||||
if (isSameMode)
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
const int bS = input->sizeAt(0);
|
||||
const int iC = input->sizeAt(1);
|
||||
const int oC = output->sizeAt(1);
|
||||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
int extraParam0 = 1;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
|
||||
algorithm,
|
||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||
&user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||
pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
dnnl::stream stream(engine);
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
delete output;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(maxpool2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
}
|
||||
if (!isNCHW) {
|
||||
input = new NDArray(
|
||||
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
output = new NDArray(
|
||||
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||
}
|
||||
|
||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||
|
||||
if (isSameMode)
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
const int bS = input->sizeAt(0);
|
||||
const int iC = input->sizeAt(1);
|
||||
const int oC = output->sizeAt(1);
|
||||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
int extraParam0 = 1;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
|
||||
algorithm,
|
||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||
&user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||
pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
dnnl::stream stream(engine);
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
delete output;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(maxpool2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||
|
||||
int kH = INT_ARG(0); // filter(kernel) height
|
||||
int kW = INT_ARG(1); // filter(kernel) width
|
||||
int sH = INT_ARG(2); // strides height
|
||||
int sW = INT_ARG(3); // strides width
|
||||
int pH = INT_ARG(4); // paddings height
|
||||
int pW = INT_ARG(5); // paddings width
|
||||
int dH = INT_ARG(6); // dilations height
|
||||
int dW = INT_ARG(7); // dilations width
|
||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||
int extraParam0 = INT_ARG(9);
|
||||
int isNCHW =
|
||||
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
|
||||
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
||||
indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(
|
||||
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(
|
||||
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
|
||||
if (!isNCHW) {
|
||||
input = new NDArray(input->permute(
|
||||
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradI = new NDArray(gradI->permute(
|
||||
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradO = new NDArray(gradO->permute(
|
||||
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||
}
|
||||
|
||||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
|
||||
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||
&user_diff_src_md, &user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
// input is sometimes null, so we can't rely on pool_src_md being valid
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
|
||||
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
|
||||
pool_dst_md, pool_strides, pool_kernel, pool_padding,
|
||||
pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
|
||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
|
||||
// probably wrong, fix that
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
delete gradI;
|
||||
delete gradO;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,174 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author saudet
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/PlatformHelper.h>
|
||||
#include <ops/declarable/OpRegistrator.h>
|
||||
#include <platform_boilerplate.h>
|
||||
|
||||
#include <helpers/MKLDNNStream.h>
|
||||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
auto gradO = INPUT_VARIABLE(
|
||||
1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(
|
||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||
|
||||
int kH = INT_ARG(0); // filter(kernel) height
|
||||
int kW = INT_ARG(1); // filter(kernel) width
|
||||
int sH = INT_ARG(2); // strides height
|
||||
int sW = INT_ARG(3); // strides width
|
||||
int pH = INT_ARG(4); // paddings height
|
||||
int pW = INT_ARG(5); // paddings width
|
||||
int dH = INT_ARG(6); // dilations height
|
||||
int dW = INT_ARG(7); // dilations width
|
||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||
int extraParam0 = INT_ARG(9);
|
||||
int isNCHW =
|
||||
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
|
||||
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
||||
indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(
|
||||
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(
|
||||
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
|
||||
if (!isNCHW) {
|
||||
input = new NDArray(input->permute(
|
||||
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradI = new NDArray(gradI->permute(
|
||||
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
gradO = new NDArray(gradO->permute(
|
||||
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||
}
|
||||
|
||||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
||||
true,
|
||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
|
||||
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||
&user_diff_src_md, &user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
// input is sometimes null, so we can't rely on pool_src_md being valid
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
|
||||
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
|
||||
pool_dst_md, pool_strides, pool_kernel, pool_padding,
|
||||
pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
|
||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
|
||||
// probably wrong, fix that
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
delete gradI;
|
||||
delete gradO;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -28,124 +28,273 @@
|
|||
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto output = OUTPUT_VARIABLE(
|
||||
0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
int kD = INT_ARG(0); // filter(kernel) depth
|
||||
int kH = INT_ARG(1); // filter(kernel) height
|
||||
int kW = INT_ARG(2); // filter(kernel) width
|
||||
int sD = INT_ARG(3); // strides depth
|
||||
int sH = INT_ARG(4); // strides height
|
||||
int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
int dD = INT_ARG(9); // dilations depth
|
||||
int dH = INT_ARG(10); // dilations height
|
||||
int dW = INT_ARG(11); // dilations width
|
||||
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
|
||||
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto output = OUTPUT_VARIABLE(
|
||||
0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||
"MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
|
||||
input->rankOf());
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
|
||||
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
int kD = INT_ARG(0); // filter(kernel) depth
|
||||
int kH = INT_ARG(1); // filter(kernel) height
|
||||
int kW = INT_ARG(2); // filter(kernel) width
|
||||
int sD = INT_ARG(3); // strides depth
|
||||
int sH = INT_ARG(4); // strides height
|
||||
int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
int dD = INT_ARG(9); // dilations depth
|
||||
int dH = INT_ARG(10); // dilations height
|
||||
int dW = INT_ARG(11); // dilations width
|
||||
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
|
||||
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||
"MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
|
||||
input->rankOf());
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
|
||||
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
|
||||
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
|
||||
"MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
|
||||
expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
|
||||
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
if (!isNCDHW) {
|
||||
input = new NDArray(
|
||||
input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
output = new NDArray(
|
||||
output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||
}
|
||||
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
|
||||
"MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
|
||||
expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
|
||||
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
|
||||
|
||||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
|
||||
dW);
|
||||
|
||||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
auto extraParam0 = 1;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||
extraParam0, true,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
|
||||
algorithm,
|
||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||
&user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||
pool_dst_md, pool_strides, pool_kernel, pool_padding,
|
||||
pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
||||
if (!isNCDHW) {
|
||||
delete input;
|
||||
delete output;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
}
|
||||
if (!isNCDHW) {
|
||||
input = new NDArray(
|
||||
input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
output = new NDArray(
|
||||
output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||
}
|
||||
|
||||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
|
||||
dW);
|
||||
|
||||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
auto extraParam0 = 1;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||
extraParam0, true,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
|
||||
algorithm,
|
||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
||||
&user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
||||
pool_dst_md, pool_strides, pool_kernel, pool_padding,
|
||||
pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = user_dst_memory;
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
}
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory}});
|
||||
|
||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
||||
if (!isNCDHW) {
|
||||
delete input;
|
||||
delete output;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
|
||||
const int kD = INT_ARG(0); // filter(kernel) depth
|
||||
const int kH = INT_ARG(1); // filter(kernel) height
|
||||
const int kW = INT_ARG(2); // filter(kernel) width
|
||||
const int sD = INT_ARG(3); // strides depth
|
||||
const int sH = INT_ARG(4); // strides height
|
||||
const int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
const int dD = INT_ARG(9); // dilations depth
|
||||
const int dH = INT_ARG(10); // dilations height
|
||||
const int dW = INT_ARG(11); // dilations width
|
||||
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
|
||||
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||
"MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
|
||||
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||
{bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||
"MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||
"MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if (!isNCDHW) {
|
||||
input = new NDArray(input->permute(
|
||||
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
gradI = new NDArray(gradI->permute(
|
||||
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
gradO = new NDArray(gradO->permute(
|
||||
{0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||
}
|
||||
|
||||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
|
||||
dW);
|
||||
|
||||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
auto extraParam0 = 1;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||
extraParam0, true,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
|
||||
algorithm,
|
||||
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||
&user_diff_src_md, &user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
// input is sometimes null, so we can't rely on pool_src_md being valid
|
||||
if (input->buffer() == nullptr) {
|
||||
pool_src_md = pool_diff_src_md;
|
||||
user_src_md = user_diff_src_md;
|
||||
}
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
|
||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
|
||||
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
|
||||
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
if (!isNCDHW) {
|
||||
delete input;
|
||||
delete gradI;
|
||||
delete gradO;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,181 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/PlatformHelper.h>
|
||||
#include <ops/declarable/OpRegistrator.h>
|
||||
#include <platform_boilerplate.h>
|
||||
|
||||
#include <helpers/MKLDNNStream.h>
|
||||
#include "mkldnnUtils.h"
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
auto gradO = INPUT_VARIABLE(
|
||||
1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||
auto gradI = OUTPUT_VARIABLE(
|
||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||
|
||||
const int kD = INT_ARG(0); // filter(kernel) depth
|
||||
const int kH = INT_ARG(1); // filter(kernel) height
|
||||
const int kW = INT_ARG(2); // filter(kernel) width
|
||||
const int sD = INT_ARG(3); // strides depth
|
||||
const int sH = INT_ARG(4); // strides height
|
||||
const int sW = INT_ARG(5); // strides width
|
||||
int pD = INT_ARG(6); // paddings depth
|
||||
int pH = INT_ARG(7); // paddings height
|
||||
int pW = INT_ARG(8); // paddings width
|
||||
const int dD = INT_ARG(9); // dilations depth
|
||||
const int dH = INT_ARG(10); // dilations height
|
||||
const int dW = INT_ARG(11); // dilations width
|
||||
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
|
||||
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||
"MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
|
||||
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
|
||||
|
||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||
{bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||
"MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
|
||||
"MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
|
||||
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if (!isNCDHW) {
|
||||
input = new NDArray(input->permute(
|
||||
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
gradI = new NDArray(gradI->permute(
|
||||
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
gradO = new NDArray(gradO->permute(
|
||||
{0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||
}
|
||||
|
||||
if (isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
|
||||
dW);
|
||||
|
||||
|
||||
auto poolingMode = PoolingType::MAX_POOL;
|
||||
auto extraParam0 = 1;
|
||||
|
||||
dnnl_memory_desc_t empty;
|
||||
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
|
||||
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
|
||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
||||
dnnl::algorithm algorithm;
|
||||
|
||||
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
|
||||
extraParam0, true,
|
||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
|
||||
algorithm,
|
||||
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
|
||||
&user_diff_src_md, &user_dst_md,
|
||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
// input is sometimes null, so we can't rely on pool_src_md being valid
|
||||
if (input->buffer() == nullptr) {
|
||||
pool_src_md = pool_diff_src_md;
|
||||
user_src_md = user_diff_src_md;
|
||||
}
|
||||
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
||||
|
||||
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
||||
|
||||
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
|
||||
auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
|
||||
|
||||
auto poolB_src_memory = userB_src_memory;
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
|
||||
}
|
||||
|
||||
auto poolB_dst_memory = userB_dst_memory;
|
||||
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
|
||||
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
|
||||
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
|
||||
}
|
||||
|
||||
|
||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
||||
|
||||
auto pool_src_memory = user_src_memory;
|
||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
||||
}
|
||||
|
||||
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
||||
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
|
||||
|
||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
||||
{DNNL_ARG_DST, pool_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
|
||||
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
|
||||
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
|
||||
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
|
||||
|
||||
|
||||
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
|
||||
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
|
||||
}
|
||||
|
||||
stream.wait();
|
||||
|
||||
if (!isNCDHW) {
|
||||
delete input;
|
||||
delete gradI;
|
||||
delete gradO;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -23,383 +23,388 @@
|
|||
|
||||
using namespace dnnl;
|
||||
|
||||
namespace nd4j {
|
||||
namespace mkldnnUtils {
|
||||
void getMKLDNNMemoryDescPool2d(
|
||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
|
||||
dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW };
|
||||
dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW };
|
||||
namespace nd4j {
|
||||
namespace mkldnnUtils {
|
||||
|
||||
pool_strides = { sH, sW };
|
||||
pool_kernel = { kH, kW };
|
||||
pool_padding = { pH, pW };
|
||||
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void getMKLDNNMemoryDescPool2d(
|
||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
|
||||
dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW };
|
||||
dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW };
|
||||
|
||||
algorithm = poolingMode == 0 ? algorithm::pooling_max
|
||||
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
|
||||
: algorithm::pooling_avg_include_padding;
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||
pool_strides = { sH, sW };
|
||||
pool_kernel = { kH, kW };
|
||||
pool_padding = { pH, pW };
|
||||
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
algorithm = poolingMode == 0 ? algorithm::pooling_max
|
||||
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
|
||||
: algorithm::pooling_avg_include_padding;
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
void getMKLDNNMemoryDescPool3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
|
||||
dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
|
||||
dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
|
||||
|
||||
pool_strides = { sD, sH, sW };
|
||||
pool_kernel = { kD, kH, kW };
|
||||
pool_padding = { pD, pH, pW };
|
||||
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
||||
(oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
|
||||
algorithm = poolingMode == 0 ? algorithm::pooling_max
|
||||
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
|
||||
: algorithm::pooling_avg_include_padding;
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
void getMKLDNNMemoryDescConv2d(
|
||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
|
||||
dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW };
|
||||
dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW };
|
||||
dnnl::memory::dims conv_bias_tz = { oC };
|
||||
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
||||
|
||||
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
|
||||
|
||||
conv_strides = { sH, sW };
|
||||
conv_padding = { pH, pW };
|
||||
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
|
||||
conv_dilation = { dH-1, dW-1};
|
||||
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
auto formatw = dnnl::memory::format_tag::hwio;
|
||||
|
||||
if (src != nullptr && conv_src_md != nullptr) {
|
||||
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
||||
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (weights != nullptr && conv_weights_md != nullptr) {
|
||||
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
|
||||
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
|
||||
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
|
||||
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
|
||||
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
|
||||
}
|
||||
|
||||
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
||||
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
|
||||
}
|
||||
|
||||
if (bias != nullptr && conv_bias_md != nullptr) {
|
||||
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
|
||||
}
|
||||
|
||||
if (dst != nullptr && conv_dst_md != nullptr) {
|
||||
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
}
|
||||
|
||||
void getMKLDNNMemoryDescConv3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
|
||||
dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
|
||||
dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
|
||||
dnnl::memory::dims conv_bias_tz = { oC };
|
||||
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
||||
|
||||
conv_strides = { sD, sH, sW };
|
||||
conv_padding = { pD, pH, pW };
|
||||
conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
|
||||
conv_dilation = { dD-1, dH-1, dW-1};
|
||||
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
auto formatw = dnnl::memory::format_tag::dhwio;
|
||||
|
||||
if (src != nullptr && conv_src_md != nullptr) {
|
||||
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
||||
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (weights != nullptr && conv_weights_md != nullptr) {
|
||||
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
|
||||
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
|
||||
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
|
||||
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
|
||||
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
|
||||
user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2];
|
||||
}
|
||||
|
||||
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
||||
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2];
|
||||
}
|
||||
|
||||
if (bias != nullptr && conv_bias_md != nullptr) {
|
||||
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
|
||||
}
|
||||
|
||||
if (dst != nullptr && conv_dst_md != nullptr) {
|
||||
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
|
||||
// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
|
||||
// const Nd4jLong* shape = src->getShapeInfo();
|
||||
// Nd4jLong rank = shape[0];
|
||||
// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
||||
// Nd4jLong dim2 = axis >= 2 ? 1 : 2;
|
||||
// Nd4jLong dim3 = axis >= 3 ? 2 : 3;
|
||||
// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||
|
||||
// auto type = dnnl::memory::data_type::f32;
|
||||
// auto format = dnnl::memory::format_tag::nchw;
|
||||
// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||
|
||||
// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
|
||||
// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_src_md->data.format_kind = dnnl_blocked; // overrides format
|
||||
// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
||||
// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
||||
// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
||||
// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
||||
// }
|
||||
|
||||
// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
|
||||
// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
||||
// }
|
||||
|
||||
// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
|
||||
// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_dst_md->data.format_kind = dnnl_blocked; // overrides format
|
||||
// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
||||
// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
||||
// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
||||
// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
||||
// }
|
||||
// };
|
||||
|
||||
|
||||
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
|
||||
const Nd4jLong* shape = src->getShapeInfo();
|
||||
long rank = shape[0];
|
||||
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
||||
long dim2 = axis >= 2 ? 1 : 2;
|
||||
long dim3 = axis >= 3 ? 2 : 3;
|
||||
dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
auto supposed_to_be_any_format = format; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
|
||||
*lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked;
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
||||
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
|
||||
*lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked;
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
|
||||
*lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked;
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
||||
}
|
||||
}
|
||||
|
||||
dnnl::engine& getEngine(void *ptr) {
|
||||
auto eng = reinterpret_cast<dnnl::engine*>(ptr);
|
||||
return *eng;
|
||||
}
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void getMKLDNNMemoryDescPool3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
|
||||
dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
|
||||
dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
|
||||
|
||||
pool_strides = { sD, sH, sW };
|
||||
pool_kernel = { kD, kH, kW };
|
||||
pool_padding = { pD, pH, pW };
|
||||
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
||||
(oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
|
||||
algorithm = poolingMode == 0 ? algorithm::pooling_max
|
||||
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
|
||||
: algorithm::pooling_avg_include_padding;
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void getMKLDNNMemoryDescConv2d(
|
||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
|
||||
dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW };
|
||||
dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW };
|
||||
dnnl::memory::dims conv_bias_tz = { oC };
|
||||
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
||||
|
||||
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
|
||||
|
||||
conv_strides = { sH, sW };
|
||||
conv_padding = { pH, pW };
|
||||
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
|
||||
conv_dilation = { dH-1, dW-1};
|
||||
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
auto formatw = dnnl::memory::format_tag::hwio;
|
||||
|
||||
if (src != nullptr && conv_src_md != nullptr) {
|
||||
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
||||
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (weights != nullptr && conv_weights_md != nullptr) {
|
||||
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
|
||||
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
|
||||
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
|
||||
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
|
||||
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
|
||||
}
|
||||
|
||||
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
||||
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
|
||||
}
|
||||
|
||||
if (bias != nullptr && conv_bias_md != nullptr) {
|
||||
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
|
||||
}
|
||||
|
||||
if (dst != nullptr && conv_dst_md != nullptr) {
|
||||
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void getMKLDNNMemoryDescConv3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
|
||||
dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
|
||||
dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
|
||||
dnnl::memory::dims conv_bias_tz = { oC };
|
||||
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
||||
|
||||
conv_strides = { sD, sH, sW };
|
||||
conv_padding = { pD, pH, pW };
|
||||
conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
|
||||
conv_dilation = { dD-1, dH-1, dW-1};
|
||||
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
auto formatw = dnnl::memory::format_tag::dhwio;
|
||||
|
||||
if (src != nullptr && conv_src_md != nullptr) {
|
||||
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
||||
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (weights != nullptr && conv_weights_md != nullptr) {
|
||||
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
|
||||
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
|
||||
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
|
||||
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
|
||||
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
|
||||
user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2];
|
||||
}
|
||||
|
||||
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
||||
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
|
||||
user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2];
|
||||
}
|
||||
|
||||
if (bias != nullptr && conv_bias_md != nullptr) {
|
||||
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
|
||||
}
|
||||
|
||||
if (dst != nullptr && conv_dst_md != nullptr) {
|
||||
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
|
||||
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
|
||||
// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
|
||||
// const Nd4jLong* shape = src->getShapeInfo();
|
||||
// Nd4jLong rank = shape[0];
|
||||
// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
||||
// Nd4jLong dim2 = axis >= 2 ? 1 : 2;
|
||||
// Nd4jLong dim3 = axis >= 3 ? 2 : 3;
|
||||
// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||
|
||||
// auto type = dnnl::memory::data_type::f32;
|
||||
// auto format = dnnl::memory::format_tag::nchw;
|
||||
// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||
|
||||
// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
|
||||
// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_src_md->data.format_kind = dnnl_blocked; // overrides format
|
||||
// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
||||
// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
||||
// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
||||
// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
||||
// }
|
||||
|
||||
// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
|
||||
// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
||||
// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
||||
// }
|
||||
|
||||
// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
|
||||
// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||
// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
|
||||
// user_dst_md->data.format_kind = dnnl_blocked; // overrides format
|
||||
// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
||||
// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
||||
// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
||||
// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
||||
// }
|
||||
// };
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
|
||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
|
||||
const Nd4jLong* shape = src->getShapeInfo();
|
||||
long rank = shape[0];
|
||||
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
||||
long dim2 = axis >= 2 ? 1 : 2;
|
||||
long dim3 = axis >= 3 ? 2 : 3;
|
||||
dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||
auto supposed_to_be_any_format = format; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
|
||||
*lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_src_md->data.format_kind = dnnl_blocked;
|
||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
||||
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
||||
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
|
||||
*lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_diff_src_md->data.format_kind = dnnl_blocked;
|
||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
||||
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
||||
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
|
||||
*lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
|
||||
user_dst_md->data.format_kind = dnnl_blocked;
|
||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
||||
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
||||
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
dnnl::engine& getEngine(void *ptr) {
|
||||
auto eng = reinterpret_cast<dnnl::engine*>(ptr);
|
||||
return *eng;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
File diff suppressed because one or more lines are too long
|
@ -970,7 +970,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) {
|
|||
|
||||
x.linspace(1);
|
||||
|
||||
|
||||
nd4j::ops::maxpool2d op;
|
||||
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1});
|
||||
|
||||
|
@ -991,7 +990,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) {
|
|||
|
||||
x.linspace(1);
|
||||
|
||||
|
||||
nd4j::ops::maxpool2d op;
|
||||
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1});
|
||||
|
||||
|
@ -1012,7 +1010,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) {
|
|||
|
||||
x.linspace(1);
|
||||
|
||||
|
||||
nd4j::ops::maxpool2d op;
|
||||
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0});
|
||||
|
||||
|
@ -1467,11 +1464,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) {
|
|||
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||||
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
|
||||
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f});
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f});
|
||||
|
||||
input.linspace(1.);
|
||||
gradO.linspace(0.1, 0.1);
|
||||
|
||||
|
|
|
@ -57,6 +57,17 @@ TEST_F(CuDnnTests, helpers_includer) {
|
|||
nd4j::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d;
|
||||
nd4j::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA depthwise_conv2d_bp;
|
||||
nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm;
|
||||
nd4j::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp;
|
||||
nd4j::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d;
|
||||
nd4j::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp;
|
||||
nd4j::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d;
|
||||
nd4j::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp;
|
||||
nd4j::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew;
|
||||
nd4j::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp;
|
||||
nd4j::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew;
|
||||
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp;
|
||||
|
||||
|
||||
|
||||
printer({&conv2d});
|
||||
printer({&conv2d_bp});
|
||||
|
@ -65,6 +76,15 @@ TEST_F(CuDnnTests, helpers_includer) {
|
|||
printer({&depthwise_conv2d});
|
||||
printer({&depthwise_conv2d_bp});
|
||||
printer({&batchnorm});
|
||||
printer({&batchnorm_bp});
|
||||
printer({&avgpool2d});
|
||||
printer({&avgpool2d_bp});
|
||||
printer({&maxpool2d});
|
||||
printer({&maxpool2d_bp});
|
||||
printer({&avgpool3dnew});
|
||||
printer({&avgpool3dnew_bp});
|
||||
printer({&maxpool3dnew});
|
||||
printer({&maxpool3dnew_bp});
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <ops/ops.h>
|
||||
#include <GradCheck.h>
|
||||
#include <memory>
|
||||
#include <PointersManager.h>
|
||||
|
||||
using namespace nd4j;
|
||||
|
||||
|
@ -2247,3 +2248,525 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) {
|
|||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) {
|
||||
|
||||
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
variance.assign(0.46666667);
|
||||
gamma.assign(1.2);
|
||||
beta.assign(1.); // has no effect on gradient calculations
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) {
|
||||
|
||||
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {3}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747,
|
||||
0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978,
|
||||
-0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
// beta.assign(1.); // has no effect on gradient calculations
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) {
|
||||
|
||||
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {2,1,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002,
|
||||
0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000,
|
||||
-0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
// beta.assign(1.); // has no effect on gradient calculations
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) {
|
||||
|
||||
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test5) {
|
||||
|
||||
#if defined(HAVE_CUDNN)
|
||||
return;
|
||||
#endif
|
||||
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243,
|
||||
-1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118,
|
||||
-0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test6) {
|
||||
|
||||
#if defined(HAVE_CUDNN)
|
||||
return;
|
||||
#endif
|
||||
|
||||
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295,
|
||||
0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295,
|
||||
-0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test7) {
|
||||
|
||||
#if defined(HAVE_CUDNN)
|
||||
return;
|
||||
#endif
|
||||
|
||||
NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142,
|
||||
-43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662,
|
||||
15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032,
|
||||
-15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788,
|
||||
-27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
// dLdI->printBuffer();
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test8) {
|
||||
|
||||
#if defined(HAVE_CUDNN)
|
||||
return;
|
||||
#endif
|
||||
|
||||
NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301,
|
||||
32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767,
|
||||
-27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526,
|
||||
30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773,
|
||||
31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
// dLdI->printBuffer();
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) {
|
||||
|
||||
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,4,2,2}, {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 0.025818,
|
||||
-0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063,
|
||||
-0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, 0.036168, 0.040426}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(1,0.01);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
// calculate mean and variance of input
|
||||
PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9");
|
||||
std::vector<int> dimensions = {0,2,3};
|
||||
int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
|
||||
input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions);
|
||||
NDArray::prepareSpecialUse({&variance}, {&input});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
|
||||
NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false);
|
||||
manager.synchronize();
|
||||
NDArray::registerSpecialUse({&variance}, {&input});
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) {
|
||||
|
||||
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,2,2,4}, {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, 0.027213, -0.029205, 0.013996, -0.015192, 0.016343,
|
||||
-0.017519, 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, -0.005452, 0.005824, -0.013974, 0.015171,
|
||||
-0.016325, 0.017508, -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, -0.038118, 0.040878}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(1,0.01);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
// calculate mean and variance of input
|
||||
PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9");
|
||||
std::vector<int> dimensions = {0,1,2};
|
||||
int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
|
||||
input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions);
|
||||
NDArray::prepareSpecialUse({&variance}, {&input});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
|
||||
NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false);
|
||||
manager.synchronize();
|
||||
NDArray::registerSpecialUse({&variance}, {&input});
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) {
|
||||
|
||||
NDArray input ('c', {2,3,4,5}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,3,4,5}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,3,4,5}, {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, 0.003001, 0.002837,
|
||||
0.002670, 0.002505, 0.002337, 0.002167, 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, 0.000996, 0.000830, 0.000664, 0.000498,
|
||||
0.000332, 0.000166, -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, -0.001003, -0.001168, -0.001337, -0.001502, -0.001670,
|
||||
-0.001838, -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821,
|
||||
-0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, -0.003985,
|
||||
-0.003836, -0.003661, -0.003505, -0.003338, -0.003171, -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, -0.002003, -0.001835,
|
||||
-0.001664, -0.001499, -0.001328, -0.001162, -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, 0.0, 0.000166, 0.000334,
|
||||
0.000500, 0.000668, 0.000834, 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, 0.002005, 0.002172, 0.002330, 0.002496, 0.002669,
|
||||
0.002836, 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {1,3,4,5}, {8.999503, 8.999502, 8.999502, 8.999503, 8.999502, 8.999503, 8.999503, 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498,
|
||||
8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, 8.999504, 8.999503,
|
||||
8.999504, 8.999504, 8.999499, 8.999500, 8.999497, 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, 8.999496, 8.999496, 8.999501,
|
||||
8.999501, 8.999499, 8.999499, 8.999499, 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, 8.999501, 8.999495, 8.999495, 8.999497}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {1,3,4,5}, {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, 14.4, 14.7, 15.0,
|
||||
15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, 21.6, 21.9, 22.2, 22.5,
|
||||
22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(1,0.01);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
gamma.linspace(-3, 0.1);
|
||||
|
||||
// calculate mean and variance of input
|
||||
PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9");
|
||||
std::vector<int> dimensions = {0};
|
||||
int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
|
||||
input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions, true);
|
||||
NDArray::prepareSpecialUse({&variance}, {&input});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
|
||||
NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false);
|
||||
manager.synchronize();
|
||||
NDArray::registerSpecialUse({&variance}, {&input});
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
|
@ -72,71 +72,6 @@ TEST_F(DeclarableOpsTests15, Test_Half_assign_1) {
|
|||
ASSERT_EQ(10, x.sumNumber().e<int>(0));
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_avgpooling_edge_1) {
|
||||
int inOutH = 5;// 35;
|
||||
int inOutW = 5;// 35;
|
||||
int inOutC = 10;// 192;
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});
|
||||
x.linspace(1.0);
|
||||
|
||||
nd4j::ops::avgpool2d op;
|
||||
auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
|
||||
int padTop = totalPadHeight / 2;
|
||||
int padBottom = totalPadHeight - totalPadHeight / 2;
|
||||
|
||||
int k = 3;
|
||||
|
||||
auto m = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});
|
||||
auto c = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});
|
||||
|
||||
for (int h = 0; h < inOutH; h++) {
|
||||
for (int w = 0; w < inOutW; w++) {
|
||||
int hFrom = h - padTop;
|
||||
int wFrom = w - padBottom;
|
||||
|
||||
int hTo = hFrom + k;
|
||||
int wTo = wFrom + k;
|
||||
|
||||
hFrom = nd4j::math::nd4j_max<int>(0, hFrom);
|
||||
wFrom = nd4j::math::nd4j_max<int>(0, wFrom);
|
||||
|
||||
hTo = nd4j::math::nd4j_min<int>(inOutH, hTo);
|
||||
wTo = nd4j::math::nd4j_min<int>(inOutW, wTo);
|
||||
|
||||
int idxOut[4];
|
||||
int idxIn[4];
|
||||
for (int ch = 0; ch < inOutC; ch++) {
|
||||
idxOut[1] = h;
|
||||
idxOut[2] = w;
|
||||
idxOut[3] = ch;
|
||||
idxIn[3] = ch;
|
||||
|
||||
for (int kh = hFrom; kh < hTo; kh++) {
|
||||
for (int kw = wFrom; kw < wTo; kw++) {
|
||||
idxIn[1] = kh;
|
||||
idxIn[2] = kw;
|
||||
|
||||
auto inVal = x.e<double>(0, kh, kw, ch);
|
||||
m.p(0, h, w, ch, inVal + m.e<double>(0, h, w, ch));
|
||||
c.p(0, h, w, ch, 1 + c.e<int>(0, h, w, ch));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
m /= c;
|
||||
|
||||
ASSERT_EQ(m, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Test_standarize_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
|
||||
auto e = NDArrayFactory::create<float>('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f});
|
||||
|
@ -1097,7 +1032,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) {
|
|||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
@ -1106,7 +1041,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) {
|
|||
// rank 2
|
||||
NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
||||
nd4j::ops::rgb_to_yuv op;
|
||||
auto result = op.execute({ &rgbs }, {}, { 0 });
|
||||
auto output = result->at(0);
|
||||
|
@ -1170,7 +1105,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) {
|
|||
// rank 3
|
||||
NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
||||
nd4j::ops::rgb_to_yuv op;
|
||||
auto result = op.execute({ &rgbs }, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
@ -1210,7 +1145,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) {
|
|||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
@ -1484,7 +1419,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test7) {
|
|||
auto Y = NDArrayFactory::create<float>(2.f);
|
||||
NDArray x('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
||||
dLdzC.linspace(0.1, 0.1);
|
||||
x = 4.f;
|
||||
|
||||
|
|
|
@ -883,22 +883,6 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_AvgPool_1) {
|
||||
auto x= NDArrayFactory::create<float>('c', {2, 10, 10, 3});
|
||||
x.linspace(1);
|
||||
|
||||
nd4j::ops::avgpool2d op;
|
||||
// kY kX sY sX pY pX dY dX M P
|
||||
auto result = op.execute({&x}, {}, {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1});
|
||||
// 0 1 2 3 4 5 6 7 8 9 10
|
||||
auto z = result->at(0);
|
||||
|
||||
// z->printShapeInfo("z shape");
|
||||
// z->printIndexedBuffer("z buffr");
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) {
|
||||
auto x= NDArrayFactory::create<double>('c', {1, 3}, {2, 2, 2});
|
||||
auto y= NDArrayFactory::create<double>('c', {1, 3}, {4, 6, 8});
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -2459,42 +2459,6 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) {
|
|||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests8, avgpool2d_test13) {
|
||||
|
||||
int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1;
|
||||
int oH=4, oW=4;
|
||||
int paddingMode = 1; // 1-SAME, 0-VALID
|
||||
int dataFormat = 1; // 1-NHWC, 0-NDHW
|
||||
|
||||
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
|
||||
auto expected = NDArrayFactory::create<double>('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5,
|
||||
182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5,
|
||||
317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5,
|
||||
482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5,
|
||||
617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5,
|
||||
782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5,
|
||||
917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5,
|
||||
1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5});
|
||||
input.linspace(1.);
|
||||
input.syncToDevice();
|
||||
|
||||
nd4j::ops::avgpool2d op;
|
||||
auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
|
||||
//output->printIndexedBuffer("output");
|
||||
//expected.printIndexedBuffer("expected");
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) {
|
||||
|
||||
|
|
|
@ -2894,344 +2894,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
|
||||
|
||||
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
variance.assign(0.46666667);
|
||||
gamma.assign(1.2);
|
||||
beta.assign(1.); // has no effect on gradient calculations
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
|
||||
|
||||
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {3}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747,
|
||||
0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978,
|
||||
-0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
// beta.assign(1.); // has no effect on gradient calculations
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
|
||||
|
||||
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {2,1,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002,
|
||||
0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000,
|
||||
-0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
// beta.assign(1.); // has no effect on gradient calculations
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {
|
||||
|
||||
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {
|
||||
|
||||
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243,
|
||||
-1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118,
|
||||
-0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {
|
||||
|
||||
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295,
|
||||
0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295,
|
||||
-0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) {
|
||||
|
||||
NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142,
|
||||
-43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662,
|
||||
15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032,
|
||||
-15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788,
|
||||
-27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
// dLdI->printBuffer();
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {
|
||||
|
||||
NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301,
|
||||
32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767,
|
||||
-27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526,
|
||||
30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773,
|
||||
31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
|
||||
NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input.linspace(0.1, 0.1);
|
||||
gradO.linspace(-0.9, 0.15);
|
||||
|
||||
nd4j::ops::batchnorm_bp op;
|
||||
|
||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto dLdI = results->at(0);
|
||||
auto dLdG = results->at(3);
|
||||
auto dLdB = results->at(4);
|
||||
|
||||
// dLdI->printBuffer();
|
||||
|
||||
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
|
||||
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||
|
||||
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
|
||||
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||
|
||||
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
|
||||
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
/*
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
|
||||
|
|
Loading…
Reference in New Issue